diff --git a/crates/simd/build.rs b/crates/simd/build.rs index 63af110d..338f9e17 100644 --- a/crates/simd/build.rs +++ b/crates/simd/build.rs @@ -23,7 +23,6 @@ fn main() -> Result<(), Box> { match target_arch.as_str() { "aarch64" => { let mut build = cc::Build::new(); - build.file("./cshim/aarch64_fp16.c"); build.file("./cshim/aarch64_byte.c"); build.file("./cshim/aarch64_halfbyte.c"); if target_endian == "little" { @@ -33,12 +32,6 @@ fn main() -> Result<(), Box> { build.opt_level(3); build.compile("simd_cshim"); } - "x86_64" => { - let mut build = cc::Build::new(); - build.file("./cshim/x86_64_fp16.c"); - build.opt_level(3); - build.compile("simd_cshim"); - } _ => {} } Ok(()) diff --git a/crates/simd/cshim/aarch64_fp16.c b/crates/simd/cshim/aarch64_fp16.c deleted file mode 100644 index ad7ffb93..00000000 --- a/crates/simd/cshim/aarch64_fp16.c +++ /dev/null @@ -1,250 +0,0 @@ -// This software is licensed under a dual license model: -// -// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and -// distribute this software under the terms of the AGPLv3. -// -// Elastic License v2 (ELv2): You may also use, modify, and distribute this -// software under the Elastic License v2, which has specific restrictions. -// -// We welcome any commercial collaboration or support. For inquiries -// regarding the licenses, please contact us at: -// vectorchord-inquiry@tensorchord.ai -// -// Copyright (c) 2025 TensorChord Inc. - -#if defined(__clang__) -#if !(__clang_major__ >= 16) -#error "Clang version must be at least 16." -#endif -#elif defined(__GNUC__) -#if !(__GNUC__ >= 14) -#error "GCC version must be at least 14." -#endif -#else -#error "This file requires Clang or GCC." -#endif - -#include -#include -#include - -typedef __fp16 f16; -typedef float f32; - -__attribute__((target("+fp16"))) float -fp16_reduce_sum_of_xy_a2_fp16(size_t n, f16 *restrict a, f16 *restrict b) { - float16x8_t sum_0 = vdupq_n_f16(0.0); - float16x8_t sum_1 = vdupq_n_f16(0.0); - float16x8_t sum_2 = vdupq_n_f16(0.0); - float16x8_t sum_3 = vdupq_n_f16(0.0); - float16x8_t sum_4 = vdupq_n_f16(0.0); - float16x8_t sum_5 = vdupq_n_f16(0.0); - float16x8_t sum_6 = vdupq_n_f16(0.0); - float16x8_t sum_7 = vdupq_n_f16(0.0); - while (n >= 64) { - float16x8_t x_0 = vld1q_f16(a + 0); - float16x8_t x_1 = vld1q_f16(a + 8); - float16x8_t x_2 = vld1q_f16(a + 16); - float16x8_t x_3 = vld1q_f16(a + 24); - float16x8_t x_4 = vld1q_f16(a + 32); - float16x8_t x_5 = vld1q_f16(a + 40); - float16x8_t x_6 = vld1q_f16(a + 48); - float16x8_t x_7 = vld1q_f16(a + 56); - float16x8_t y_0 = vld1q_f16(b + 0); - float16x8_t y_1 = vld1q_f16(b + 8); - float16x8_t y_2 = vld1q_f16(b + 16); - float16x8_t y_3 = vld1q_f16(b + 24); - float16x8_t y_4 = vld1q_f16(b + 32); - float16x8_t y_5 = vld1q_f16(b + 40); - float16x8_t y_6 = vld1q_f16(b + 48); - float16x8_t y_7 = vld1q_f16(b + 56); - sum_0 = vfmaq_f16(sum_0, x_0, y_0); - sum_1 = vfmaq_f16(sum_1, x_1, y_1); - sum_2 = vfmaq_f16(sum_2, x_2, y_2); - sum_3 = vfmaq_f16(sum_3, x_3, y_3); - sum_4 = vfmaq_f16(sum_4, x_4, y_4); - sum_5 = vfmaq_f16(sum_5, x_5, y_5); - sum_6 = vfmaq_f16(sum_6, x_6, y_6); - sum_7 = vfmaq_f16(sum_7, x_7, y_7); - n -= 64, a += 64, b += 64; - } - if (n >= 32) { - float16x8_t x_0 = vld1q_f16(a + 0); - float16x8_t x_1 = vld1q_f16(a + 8); - float16x8_t x_2 = vld1q_f16(a + 16); - float16x8_t x_3 = vld1q_f16(a + 24); - float16x8_t y_0 = vld1q_f16(b + 0); - float16x8_t y_1 = vld1q_f16(b + 8); - float16x8_t y_2 = vld1q_f16(b + 16); - float16x8_t y_3 = vld1q_f16(b + 24); - sum_0 = 
vfmaq_f16(sum_0, x_0, y_0); - sum_1 = vfmaq_f16(sum_1, x_1, y_1); - sum_2 = vfmaq_f16(sum_2, x_2, y_2); - sum_3 = vfmaq_f16(sum_3, x_3, y_3); - n -= 32, a += 32, b += 32; - } - if (n >= 16) { - float16x8_t x_4 = vld1q_f16(a + 0); - float16x8_t x_5 = vld1q_f16(a + 8); - float16x8_t y_4 = vld1q_f16(b + 0); - float16x8_t y_5 = vld1q_f16(b + 8); - sum_4 = vfmaq_f16(sum_4, x_4, y_4); - sum_5 = vfmaq_f16(sum_5, x_5, y_5); - n -= 16, a += 16, b += 16; - } - if (n >= 8) { - float16x8_t x_6 = vld1q_f16(a + 0); - float16x8_t y_6 = vld1q_f16(b + 0); - sum_6 = vfmaq_f16(sum_6, x_6, y_6); - n -= 8, a += 8, b += 8; - } - if (n > 0) { - f16 _a[8] = {}, _b[8] = {}; - for (size_t i = 0; i < n; i += 1) { - _a[i] = a[i], _b[i] = b[i]; - } - a = _a, b = _b; - float16x8_t x_7 = vld1q_f16(a); - float16x8_t y_7 = vld1q_f16(b); - sum_7 = vfmaq_f16(sum_7, x_7, y_7); - } - float32x4_t s_0 = vcvt_f32_f16(vget_low_f16(sum_0)); - float32x4_t s_1 = vcvt_f32_f16(vget_high_f16(sum_0)); - float32x4_t s_2 = vcvt_f32_f16(vget_low_f16(sum_1)); - float32x4_t s_3 = vcvt_f32_f16(vget_high_f16(sum_1)); - float32x4_t s_4 = vcvt_f32_f16(vget_low_f16(sum_2)); - float32x4_t s_5 = vcvt_f32_f16(vget_high_f16(sum_2)); - float32x4_t s_6 = vcvt_f32_f16(vget_low_f16(sum_3)); - float32x4_t s_7 = vcvt_f32_f16(vget_high_f16(sum_3)); - float32x4_t s_8 = vcvt_f32_f16(vget_low_f16(sum_4)); - float32x4_t s_9 = vcvt_f32_f16(vget_high_f16(sum_4)); - float32x4_t s_a = vcvt_f32_f16(vget_low_f16(sum_5)); - float32x4_t s_b = vcvt_f32_f16(vget_high_f16(sum_5)); - float32x4_t s_c = vcvt_f32_f16(vget_low_f16(sum_6)); - float32x4_t s_d = vcvt_f32_f16(vget_high_f16(sum_6)); - float32x4_t s_e = vcvt_f32_f16(vget_low_f16(sum_7)); - float32x4_t s_f = vcvt_f32_f16(vget_high_f16(sum_7)); - float32x4_t s = vpaddq_f32( - vpaddq_f32(vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)), - vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7))), - vpaddq_f32(vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)), - vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)))); - return vaddvq_f32(s); -} - -__attribute__((target("+fp16"))) float -fp16_reduce_sum_of_d2_a2_fp16(size_t n, f16 *restrict a, f16 *restrict b) { - float16x8_t sum_0 = vdupq_n_f16(0.0); - float16x8_t sum_1 = vdupq_n_f16(0.0); - float16x8_t sum_2 = vdupq_n_f16(0.0); - float16x8_t sum_3 = vdupq_n_f16(0.0); - float16x8_t sum_4 = vdupq_n_f16(0.0); - float16x8_t sum_5 = vdupq_n_f16(0.0); - float16x8_t sum_6 = vdupq_n_f16(0.0); - float16x8_t sum_7 = vdupq_n_f16(0.0); - while (n >= 64) { - float16x8_t x_0 = vld1q_f16(a + 0); - float16x8_t x_1 = vld1q_f16(a + 8); - float16x8_t x_2 = vld1q_f16(a + 16); - float16x8_t x_3 = vld1q_f16(a + 24); - float16x8_t x_4 = vld1q_f16(a + 32); - float16x8_t x_5 = vld1q_f16(a + 40); - float16x8_t x_6 = vld1q_f16(a + 48); - float16x8_t x_7 = vld1q_f16(a + 56); - float16x8_t y_0 = vld1q_f16(b + 0); - float16x8_t y_1 = vld1q_f16(b + 8); - float16x8_t y_2 = vld1q_f16(b + 16); - float16x8_t y_3 = vld1q_f16(b + 24); - float16x8_t y_4 = vld1q_f16(b + 32); - float16x8_t y_5 = vld1q_f16(b + 40); - float16x8_t y_6 = vld1q_f16(b + 48); - float16x8_t y_7 = vld1q_f16(b + 56); - float16x8_t d_0 = vsubq_f16(x_0, y_0); - float16x8_t d_1 = vsubq_f16(x_1, y_1); - float16x8_t d_2 = vsubq_f16(x_2, y_2); - float16x8_t d_3 = vsubq_f16(x_3, y_3); - float16x8_t d_4 = vsubq_f16(x_4, y_4); - float16x8_t d_5 = vsubq_f16(x_5, y_5); - float16x8_t d_6 = vsubq_f16(x_6, y_6); - float16x8_t d_7 = vsubq_f16(x_7, y_7); - sum_0 = vfmaq_f16(sum_0, d_0, d_0); - sum_1 = vfmaq_f16(sum_1, d_1, d_1); - sum_2 = 
vfmaq_f16(sum_2, d_2, d_2); - sum_3 = vfmaq_f16(sum_3, d_3, d_3); - sum_4 = vfmaq_f16(sum_4, d_4, d_4); - sum_5 = vfmaq_f16(sum_5, d_5, d_5); - sum_6 = vfmaq_f16(sum_6, d_6, d_6); - sum_7 = vfmaq_f16(sum_7, d_7, d_7); - n -= 64, a += 64, b += 64; - } - if (n >= 32) { - float16x8_t x_0 = vld1q_f16(a + 0); - float16x8_t x_1 = vld1q_f16(a + 8); - float16x8_t x_2 = vld1q_f16(a + 16); - float16x8_t x_3 = vld1q_f16(a + 24); - float16x8_t y_0 = vld1q_f16(b + 0); - float16x8_t y_1 = vld1q_f16(b + 8); - float16x8_t y_2 = vld1q_f16(b + 16); - float16x8_t y_3 = vld1q_f16(b + 24); - float16x8_t d_0 = vsubq_f16(x_0, y_0); - float16x8_t d_1 = vsubq_f16(x_1, y_1); - float16x8_t d_2 = vsubq_f16(x_2, y_2); - float16x8_t d_3 = vsubq_f16(x_3, y_3); - sum_0 = vfmaq_f16(sum_0, d_0, d_0); - sum_1 = vfmaq_f16(sum_1, d_1, d_1); - sum_2 = vfmaq_f16(sum_2, d_2, d_2); - sum_3 = vfmaq_f16(sum_3, d_3, d_3); - n -= 32, a += 32, b += 32; - } - if (n >= 16) { - float16x8_t x_4 = vld1q_f16(a + 0); - float16x8_t x_5 = vld1q_f16(a + 8); - float16x8_t y_4 = vld1q_f16(b + 0); - float16x8_t y_5 = vld1q_f16(b + 8); - float16x8_t d_4 = vsubq_f16(x_4, y_4); - float16x8_t d_5 = vsubq_f16(x_5, y_5); - sum_4 = vfmaq_f16(sum_4, d_4, d_4); - sum_5 = vfmaq_f16(sum_5, d_5, d_5); - n -= 16, a += 16, b += 16; - } - if (n >= 8) { - float16x8_t x_6 = vld1q_f16(a + 0); - float16x8_t y_6 = vld1q_f16(b + 0); - float16x8_t d_6 = vsubq_f16(x_6, y_6); - sum_6 = vfmaq_f16(sum_6, d_6, d_6); - n -= 8, a += 8, b += 8; - } - if (n > 0) { - f16 _a[8] = {}, _b[8] = {}; - for (size_t i = 0; i < n; i += 1) { - _a[i] = a[i], _b[i] = b[i]; - } - a = _a, b = _b; - float16x8_t x_7 = vld1q_f16(a); - float16x8_t y_7 = vld1q_f16(b); - float16x8_t d_7 = vsubq_f16(x_7, y_7); - sum_7 = vfmaq_f16(sum_7, d_7, d_7); - } - float32x4_t s_0 = vcvt_f32_f16(vget_low_f16(sum_0)); - float32x4_t s_1 = vcvt_f32_f16(vget_high_f16(sum_0)); - float32x4_t s_2 = vcvt_f32_f16(vget_low_f16(sum_1)); - float32x4_t s_3 = vcvt_f32_f16(vget_high_f16(sum_1)); - float32x4_t s_4 = vcvt_f32_f16(vget_low_f16(sum_2)); - float32x4_t s_5 = vcvt_f32_f16(vget_high_f16(sum_2)); - float32x4_t s_6 = vcvt_f32_f16(vget_low_f16(sum_3)); - float32x4_t s_7 = vcvt_f32_f16(vget_high_f16(sum_3)); - float32x4_t s_8 = vcvt_f32_f16(vget_low_f16(sum_4)); - float32x4_t s_9 = vcvt_f32_f16(vget_high_f16(sum_4)); - float32x4_t s_a = vcvt_f32_f16(vget_low_f16(sum_5)); - float32x4_t s_b = vcvt_f32_f16(vget_high_f16(sum_5)); - float32x4_t s_c = vcvt_f32_f16(vget_low_f16(sum_6)); - float32x4_t s_d = vcvt_f32_f16(vget_high_f16(sum_6)); - float32x4_t s_e = vcvt_f32_f16(vget_low_f16(sum_7)); - float32x4_t s_f = vcvt_f32_f16(vget_high_f16(sum_7)); - float32x4_t s = vpaddq_f32( - vpaddq_f32(vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)), - vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7))), - vpaddq_f32(vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)), - vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)))); - return vaddvq_f32(s); -} diff --git a/crates/simd/cshim/x86_64_fp16.c b/crates/simd/cshim/x86_64_fp16.c deleted file mode 100644 index 79a204b5..00000000 --- a/crates/simd/cshim/x86_64_fp16.c +++ /dev/null @@ -1,141 +0,0 @@ -// This software is licensed under a dual license model: -// -// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and -// distribute this software under the terms of the AGPLv3. -// -// Elastic License v2 (ELv2): You may also use, modify, and distribute this -// software under the Elastic License v2, which has specific restrictions. 
-// -// We welcome any commercial collaboration or support. For inquiries -// regarding the licenses, please contact us at: -// vectorchord-inquiry@tensorchord.ai -// -// Copyright (c) 2025 TensorChord Inc. - -#if defined(__clang__) -#if !(__clang_major__ >= 16) -#error "Clang version must be at least 16." -#endif -#elif defined(__GNUC__) -#if !(__GNUC__ >= 12) -#error "GCC version must be at least 12." -#endif -#else -#error "This file requires Clang or GCC." -#endif - -#include -#include -#include - -#if defined(__clang__) && defined(_MSC_VER) && (__clang_major__ <= 19) -// https://github.com/llvm/llvm-project/issues/53520 -// clang-format off -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -// clang-format on -#endif - -typedef _Float16 f16; -typedef float f32; - -__attribute__((target("avx512bw,avx512cd,avx512dq,avx512vl,bmi,bmi2,lzcnt," - "movbe,popcnt,avx512fp16"))) float -fp16_reduce_sum_of_xy_v4_avx512fp16(size_t n, f16 *restrict a, - f16 *restrict b) { - __m512h _0 = _mm512_setzero_ph(); - __m512h _1 = _mm512_setzero_ph(); - while (n >= 64) { - __m512h x_0 = _mm512_loadu_ph(a + 0); - __m512h x_1 = _mm512_loadu_ph(a + 32); - __m512h y_0 = _mm512_loadu_ph(b + 0); - __m512h y_1 = _mm512_loadu_ph(b + 32); - _0 = _mm512_fmadd_ph(x_0, y_0, _0); - _1 = _mm512_fmadd_ph(x_1, y_1, _1); - n -= 64, a += 64, b += 64; - } - while (n >= 32) { - __m512h x_0 = _mm512_loadu_ph(a + 0); - __m512h y_0 = _mm512_loadu_ph(b + 0); - _0 = _mm512_fmadd_ph(x_0, y_0, _0); - n -= 32, a += 32, b += 32; - } - if (n > 0) { - unsigned int mask = _bzhi_u32(0xffffffff, n); - __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); - __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); - _1 = _mm512_fmadd_ph(x, y, _1); - } - __m512 s_0 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0)); - __m512 s_1 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1)); - __m512 s_2 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0)); - __m512 s_3 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1)); - return _mm512_reduce_add_ps( - _mm512_add_ps(_mm512_add_ps(s_0, s_2), _mm512_add_ps(s_1, s_3))); -} - -__attribute__((target("avx512bw,avx512cd,avx512dq,avx512vl,bmi,bmi2,lzcnt," - "movbe,popcnt,avx512fp16"))) float -fp16_reduce_sum_of_d2_v4_avx512fp16(size_t n, f16 *restrict a, - f16 *restrict b) { - __m512h _0 = _mm512_setzero_ph(); - __m512h _1 = _mm512_setzero_ph(); - while (n >= 64) { - __m512h x_0 = _mm512_loadu_ph(a + 0); - __m512h x_1 = _mm512_loadu_ph(a + 32); - __m512h y_0 = _mm512_loadu_ph(b + 0); - __m512h y_1 = _mm512_loadu_ph(b + 32); - __m512h d_0 = _mm512_sub_ph(x_0, y_0); - __m512h d_1 = _mm512_sub_ph(x_1, y_1); - _0 = _mm512_fmadd_ph(d_0, d_0, _0); - _1 = _mm512_fmadd_ph(d_1, d_1, _1); - n -= 64, a += 64, b += 64; - } - while (n >= 32) { - __m512h x_0 = _mm512_loadu_ph(a + 0); - __m512h y_0 = _mm512_loadu_ph(b + 0); - __m512h d_0 = _mm512_sub_ph(x_0, y_0); - _0 = _mm512_fmadd_ph(d_0, d_0, _0); - n -= 32, a += 32, b += 32; - } - if (n > 0) { - unsigned int mask = _bzhi_u32(0xffffffff, n); - __m512h x_1 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); - __m512h y_1 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); - __m512h d_1 = _mm512_sub_ph(x_1, y_1); - _1 = _mm512_fmadd_ph(d_1, d_1, _1); - } - __m512 s_0 = - 
_mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0)); - __m512 s_1 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1)); - __m512 s_2 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0)); - __m512 s_3 = - _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1)); - return _mm512_reduce_add_ps( - _mm512_add_ps(_mm512_add_ps(s_0, s_2), _mm512_add_ps(s_1, s_3))); -} diff --git a/crates/simd/src/emulate.rs b/crates/simd/src/emulate.rs index b84e5e8f..e710ab6a 100644 --- a/crates/simd/src/emulate.rs +++ b/crates/simd/src/emulate.rs @@ -209,3 +209,21 @@ pub fn emulate_mm256_reduce_add_epi64(mut x: core::arch::x86_64::__m256i) -> i64 x = _mm256_add_epi64(x, _mm256_permute2f128_si256(x, x, 1)); _mm256_extract_epi64(x, 0) + _mm256_extract_epi64(x, 1) } + +#[inline] +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub fn emulate_vreinterpret_f16_u16( + x: core::arch::aarch64::uint16x4_t, +) -> core::arch::aarch64::float16x4_t { + unsafe { core::mem::transmute(x) } +} + +#[inline] +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub fn emulate_vreinterpretq_f16_u16( + x: core::arch::aarch64::uint16x8_t, +) -> core::arch::aarch64::float16x8_t { + unsafe { core::mem::transmute(x) } +} diff --git a/crates/simd/src/floating_f16.rs b/crates/simd/src/floating_f16.rs index e7b0d7ea..9438f870 100644 --- a/crates/simd/src/floating_f16.rs +++ b/crates/simd/src/floating_f16.rs @@ -165,6 +165,71 @@ mod reduce_or_of_is_zero_x { mod reduce_sum_of_x { use crate::{F16, f16}; + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + fn reduce_sum_of_x_v4_avx512fp16(this: &[f16]) -> f32 { + use core::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut _0 = _mm512_setzero_ph(); + let mut _1 = _mm512_setzero_ph(); + while n >= 64 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(32).cast())) }; + _0 = _mm512_add_ph(_0, x_0); + _1 = _mm512_add_ph(_1, x_1); + (n, a) = unsafe { (n - 64, a.add(64)) }; + } + if n >= 32 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + _0 = _mm512_add_ph(_0, x_0); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())) }; + _1 = _mm512_add_ph(_1, x_1); + } + let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0)); + let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1)); + let s_2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0)); + let s_3 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1)); + _mm512_reduce_add_ps(_mm512_add_ps( + _mm512_add_ps(s_0, s_2), + _mm512_add_ps(s_1, s_3), + )) + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_x_v4_avx512fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... 
skipped (v4:avx512fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v4_avx512fp16(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v4")] @@ -262,8 +327,187 @@ mod reduce_sum_of_x { } } + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + #[target_feature(enable = "fp16")] + fn reduce_sum_of_x_a2_fp16(this: &[f16]) -> f32 { + use crate::emulate::partial_load; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum_0 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_1 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_2 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_3 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_4 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_5 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_6 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_7 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + while n >= 64 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(32).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(40).cast()) }); + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(48).cast()) }); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(56).cast()) }); + sum_0 = vaddq_f16(sum_0, x_0); + sum_1 = vaddq_f16(sum_1, x_1); + sum_2 = vaddq_f16(sum_2, x_2); + sum_3 = vaddq_f16(sum_3, x_3); + sum_4 = vaddq_f16(sum_4, x_4); + sum_5 = vaddq_f16(sum_5, x_5); + sum_6 = vaddq_f16(sum_6, x_6); + sum_7 = vaddq_f16(sum_7, x_7); + (n, a) = unsafe { (n - 64, a.add(64)) }; + } + if n >= 32 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + sum_0 = vaddq_f16(sum_0, x_0); + sum_1 = vaddq_f16(sum_1, x_1); + sum_2 = vaddq_f16(sum_2, x_2); + sum_3 = vaddq_f16(sum_3, x_3); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n >= 16 { + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + sum_4 = vaddq_f16(sum_4, x_4); + sum_5 = vaddq_f16(sum_5, x_5); + (n, a) = unsafe { (n - 16, a.add(16)) }; + } + if n >= 8 { + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + sum_6 = vaddq_f16(sum_6, x_6); + (n, a) = unsafe { (n - 8, a.add(8)) }; + } + if n > 0 { + let (_a,) = unsafe { partial_load!(8, n, a) }; + (a,) = (_a.as_ptr(),); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) }); + sum_7 = vaddq_f16(sum_7, x_7); + } + let s_0 = vcvt_f32_f16(vget_low_f16(sum_0)); + let s_1 = vcvt_f32_f16(vget_high_f16(sum_0)); + 
let s_2 = vcvt_f32_f16(vget_low_f16(sum_1)); + let s_3 = vcvt_f32_f16(vget_high_f16(sum_1)); + let s_4 = vcvt_f32_f16(vget_low_f16(sum_2)); + let s_5 = vcvt_f32_f16(vget_high_f16(sum_2)); + let s_6 = vcvt_f32_f16(vget_low_f16(sum_3)); + let s_7 = vcvt_f32_f16(vget_high_f16(sum_3)); + let s_8 = vcvt_f32_f16(vget_low_f16(sum_4)); + let s_9 = vcvt_f32_f16(vget_high_f16(sum_4)); + let s_a = vcvt_f32_f16(vget_low_f16(sum_5)); + let s_b = vcvt_f32_f16(vget_high_f16(sum_5)); + let s_c = vcvt_f32_f16(vget_low_f16(sum_6)); + let s_d = vcvt_f32_f16(vget_high_f16(sum_6)); + let s_e = vcvt_f32_f16(vget_low_f16(sum_7)); + let s_f = vcvt_f32_f16(vget_high_f16(sum_7)); + let s = vpaddq_f32( + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)), + vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7)), + ), + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)), + vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)), + ), + ); + vaddvq_f32(s) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_x_a2_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("a2") || !crate::is_feature_detected!("fp16") { + println!("test {} ... skipped (a2:fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_a2_fp16(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_x_a2(this: &[f16]) -> f32 { + use crate::emulate::emulate_vreinterpret_f16_u16; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + sum = vaddq_f32(sum, x); + (n, a) = unsafe { (n - 4, a.add(4)) }; + } + if n > 0 { + let mut _a = [f16::_ZERO; 4]; + let mut _b = [f16::_ZERO; 4]; + unsafe { + std::ptr::copy_nonoverlapping(a.cast(), _a.as_mut_ptr(), n); + } + (a,) = (_a.as_ptr(),); + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + sum = vaddq_f32(sum, x); + } + vaddvq_f32(sum) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_x_a2_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_a2(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + #[crate::multiversion( - @"v4", @"v3", "v2", "a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" + @"v4:avx512fp16", @"v4", @"v3", "v2", @"a2:fp16", @"a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" )] pub fn reduce_sum_of_x(this: &[f16]) -> f32 { let n = this.len(); @@ -278,6 +522,71 @@ mod reduce_sum_of_x { mod reduce_sum_of_abs_x { use crate::{F16, f16}; + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + fn reduce_sum_of_abs_x_v4_avx512fp16(this: &[f16]) -> f32 { + use core::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut _0 = _mm512_setzero_ph(); + let mut _1 = _mm512_setzero_ph(); + while n >= 64 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(32).cast())) }; + _0 = _mm512_add_ph(_0, _mm512_abs_ph(x_0)); + _1 = _mm512_add_ph(_1, _mm512_abs_ph(x_1)); + (n, a) = unsafe { (n - 64, a.add(64)) }; + } + while n >= 32 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + _0 = _mm512_add_ph(_0, _mm512_abs_ph(x_0)); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())) }; + _1 = _mm512_add_ph(_1, _mm512_abs_ph(x_1)); + } + let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0)); + let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1)); + let s_2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0)); + let s_3 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1)); + _mm512_reduce_add_ps(_mm512_add_ps( + _mm512_add_ps(s_0, s_2), + _mm512_add_ps(s_1, s_3), + )) + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_abs_x_v4_avx512fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... skipped (v4:avx512fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v4_avx512fp16(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v4")] @@ -376,8 +685,187 @@ mod reduce_sum_of_abs_x { } } + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + #[target_feature(enable = "fp16")] + fn reduce_sum_of_abs_x_a2_fp16(this: &[f16]) -> f32 { + use crate::emulate::partial_load; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum_0 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_1 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_2 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_3 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_4 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_5 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_6 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_7 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + while n >= 64 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(32).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(40).cast()) }); + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(48).cast()) }); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(56).cast()) }); + sum_0 = vaddq_f16(sum_0, vabsq_f16(x_0)); + sum_1 = vaddq_f16(sum_1, vabsq_f16(x_1)); + sum_2 = vaddq_f16(sum_2, vabsq_f16(x_2)); + sum_3 = vaddq_f16(sum_3, vabsq_f16(x_3)); + sum_4 = vaddq_f16(sum_4, vabsq_f16(x_4)); + sum_5 = vaddq_f16(sum_5, vabsq_f16(x_5)); + sum_6 = vaddq_f16(sum_6, vabsq_f16(x_6)); + sum_7 = vaddq_f16(sum_7, vabsq_f16(x_7)); + (n, a) = unsafe { (n - 64, a.add(64)) }; + } + if n >= 32 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + sum_0 = vaddq_f16(sum_0, vabsq_f16(x_0)); + sum_1 = vaddq_f16(sum_1, vabsq_f16(x_1)); + sum_2 = vaddq_f16(sum_2, vabsq_f16(x_2)); + sum_3 = vaddq_f16(sum_3, vabsq_f16(x_3)); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n >= 16 { + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + sum_4 = vaddq_f16(sum_4, vabsq_f16(x_4)); + sum_5 = vaddq_f16(sum_5, vabsq_f16(x_5)); + (n, a) = unsafe { (n - 16, a.add(16)) }; + } + if n >= 8 { + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + sum_6 = vaddq_f16(sum_6, vabsq_f16(x_6)); + (n, a) = unsafe { (n - 8, a.add(8)) }; + } + if n > 0 { + let (_a,) = unsafe { partial_load!(8, n, a) }; + (a,) = (_a.as_ptr(),); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) }); + sum_7 = vaddq_f16(sum_7, vabsq_f16(x_7)); + } + let s_0 = vcvt_f32_f16(vget_low_f16(sum_0)); + let s_1 = vcvt_f32_f16(vget_high_f16(sum_0)); + let s_2 = vcvt_f32_f16(vget_low_f16(sum_1)); + let s_3 = vcvt_f32_f16(vget_high_f16(sum_1)); + let s_4 = vcvt_f32_f16(vget_low_f16(sum_2)); + let s_5 = vcvt_f32_f16(vget_high_f16(sum_2)); + let s_6 = vcvt_f32_f16(vget_low_f16(sum_3)); + let s_7 = vcvt_f32_f16(vget_high_f16(sum_3)); + let s_8 = 
vcvt_f32_f16(vget_low_f16(sum_4)); + let s_9 = vcvt_f32_f16(vget_high_f16(sum_4)); + let s_a = vcvt_f32_f16(vget_low_f16(sum_5)); + let s_b = vcvt_f32_f16(vget_high_f16(sum_5)); + let s_c = vcvt_f32_f16(vget_low_f16(sum_6)); + let s_d = vcvt_f32_f16(vget_high_f16(sum_6)); + let s_e = vcvt_f32_f16(vget_low_f16(sum_7)); + let s_f = vcvt_f32_f16(vget_high_f16(sum_7)); + let s = vpaddq_f32( + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)), + vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7)), + ), + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)), + vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)), + ), + ); + vaddvq_f32(s) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_abs_x_a2_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("a2") || !crate::is_feature_detected!("fp16") { + println!("test {} ... skipped (a2:fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_a2_fp16(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_abs_x_a2(this: &[f16]) -> f32 { + use crate::emulate::emulate_vreinterpret_f16_u16; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + sum = vaddq_f32(sum, vabsq_f32(x)); + (n, a) = unsafe { (n - 4, a.add(4)) }; + } + if n > 0 { + let mut _a = [f16::_ZERO; 4]; + let mut _b = [f16::_ZERO; 4]; + unsafe { + std::ptr::copy_nonoverlapping(a.cast(), _a.as_mut_ptr(), n); + } + (a,) = (_a.as_ptr(),); + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + sum = vaddq_f32(sum, vabsq_f32(x)); + } + vaddvq_f32(sum) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_abs_x_a2_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_a2(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + #[crate::multiversion( - @"v4", @"v3", "v2", "a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" + @"v4:avx512fp16", @"v4", @"v3", "v2", @"a2:fp16", @"a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" )] pub fn reduce_sum_of_abs_x(this: &[f16]) -> f32 { let n = this.len(); @@ -392,6 +880,71 @@ mod reduce_sum_of_abs_x { mod reduce_sum_of_x2 { use crate::{F16, f16}; + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + fn reduce_sum_of_x2_v4_avx512fp16(this: &[f16]) -> f32 { + use core::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut _0 = _mm512_setzero_ph(); + let mut _1 = _mm512_setzero_ph(); + while n >= 64 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(32).cast())) }; + _0 = _mm512_fmadd_ph(x_0, x_0, _0); + _1 = _mm512_fmadd_ph(x_1, x_1, _1); + (n, a) = unsafe { (n - 64, a.add(64)) }; + } + if n >= 32 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + _0 = _mm512_fmadd_ph(x_0, x_0, _0); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())) }; + _1 = _mm512_fmadd_ph(x_1, x_1, _1); + } + let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0)); + let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1)); + let s_2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0)); + let s_3 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1)); + _mm512_reduce_add_ps(_mm512_add_ps( + _mm512_add_ps(s_0, s_2), + _mm512_add_ps(s_1, s_3), + )) + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_x2_v4_avx512fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... skipped (v4:avx512fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v4_avx512fp16(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v4")] @@ -464,11 +1017,186 @@ mod reduce_sum_of_x2 { #[cfg(all(target_arch = "x86_64", test))] #[test] #[cfg_attr(miri, ignore)] - fn reduce_sum_of_x2_v3_test() { + fn reduce_sum_of_x2_v3_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v3(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + #[target_feature(enable = "fp16")] + fn reduce_sum_of_x2_a2_fp16(this: &[f16]) -> f32 { + use crate::emulate::partial_load; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum_0 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_1 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_2 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_3 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_4 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_5 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_6 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_7 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + while n >= 64 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(32).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(40).cast()) }); + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(48).cast()) }); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(56).cast()) }); + sum_0 = vfmaq_f16(sum_0, x_0, x_0); + sum_1 = vfmaq_f16(sum_1, x_1, x_1); + sum_2 = vfmaq_f16(sum_2, x_2, x_2); + sum_3 = vfmaq_f16(sum_3, x_3, x_3); + sum_4 = vfmaq_f16(sum_4, x_4, x_4); + sum_5 = vfmaq_f16(sum_5, x_5, x_5); + sum_6 = vfmaq_f16(sum_6, x_6, x_6); + sum_7 = vfmaq_f16(sum_7, x_7, x_7); + (n, a) = unsafe { (n - 64, a.add(64)) }; + } + if n >= 32 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + sum_0 = vfmaq_f16(sum_0, x_0, x_0); + sum_1 = vfmaq_f16(sum_1, x_1, x_1); + sum_2 = vfmaq_f16(sum_2, x_2, x_2); + sum_3 = vfmaq_f16(sum_3, x_3, x_3); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n >= 16 { + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + sum_4 = vfmaq_f16(sum_4, x_4, x_4); + sum_5 = vfmaq_f16(sum_5, x_5, x_5); + (n, a) = unsafe { (n - 16, a.add(16)) }; + } + if n >= 8 { + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + sum_6 = vfmaq_f16(sum_6, x_6, x_6); + (n, a) = unsafe { (n - 8, a.add(8)) }; + } + if n > 0 { + let (_a,) = unsafe { partial_load!(8, n, a) }; + (a,) = (_a.as_ptr(),); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) }); + sum_7 = vfmaq_f16(sum_7, x_7, x_7); + } + let s_0 = vcvt_f32_f16(vget_low_f16(sum_0)); + let s_1 = vcvt_f32_f16(vget_high_f16(sum_0)); + let s_2 = vcvt_f32_f16(vget_low_f16(sum_1)); + let s_3 = vcvt_f32_f16(vget_high_f16(sum_1)); + let s_4 = vcvt_f32_f16(vget_low_f16(sum_2)); + let s_5 = vcvt_f32_f16(vget_high_f16(sum_2)); + let s_6 = vcvt_f32_f16(vget_low_f16(sum_3)); + let s_7 = vcvt_f32_f16(vget_high_f16(sum_3)); + let s_8 = vcvt_f32_f16(vget_low_f16(sum_4)); + let s_9 = vcvt_f32_f16(vget_high_f16(sum_4)); + let s_a = vcvt_f32_f16(vget_low_f16(sum_5)); + let s_b = vcvt_f32_f16(vget_high_f16(sum_5)); + let s_c = vcvt_f32_f16(vget_low_f16(sum_6)); + let s_d = 
vcvt_f32_f16(vget_high_f16(sum_6)); + let s_e = vcvt_f32_f16(vget_low_f16(sum_7)); + let s_f = vcvt_f32_f16(vget_high_f16(sum_7)); + let s = vpaddq_f32( + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)), + vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7)), + ), + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)), + vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)), + ), + ); + vaddvq_f32(s) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_x2_a2_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("a2") || !crate::is_feature_detected!("fp16") { + println!("test {} ... skipped (a2:fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_a2_fp16(this) }; + let fallback = fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_x2_a2(this: &[f16]) -> f32 { + use crate::emulate::{emulate_vreinterpret_f16_u16, partial_load}; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + sum = vfmaq_f32(sum, x, x); + (n, a) = unsafe { (n - 4, a.add(4)) }; + } + if n > 0 { + let (_a,) = unsafe { partial_load!(4, n, a) }; + (a,) = (_a.as_ptr(),); + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + sum = vfmaq_f32(sum, x, x); + } + vaddvq_f32(sum) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_sum_of_x2_a2_test() { use rand::Rng; const EPSILON: f32 = 2.0; - if !crate::is_cpu_detected!("v3") { - println!("test {} ... skipped (v3)", module_path!()); + if !crate::is_cpu_detected!("a2") { + println!("test {} ... 
skipped (a2)", module_path!()); return; } let mut rng = rand::rng(); @@ -479,7 +1207,7 @@ mod reduce_sum_of_x2 { .collect::>(); for z in 3984..4016 { let this = &this[..z]; - let specialized = unsafe { reduce_sum_of_x2_v3(this) }; + let specialized = unsafe { reduce_sum_of_x2_a2(this) }; let fallback = fallback(this); assert!( (specialized - fallback).abs() < EPSILON, @@ -490,7 +1218,7 @@ mod reduce_sum_of_x2 { } #[crate::multiversion( - @"v4", @"v3", "v2", "a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" + @"v4:avx512fp16", @"v4", @"v3", "v2", @"a2:fp16", @"a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" )] pub fn reduce_sum_of_x2(this: &[f16]) -> f32 { let n = this.len(); @@ -505,6 +1233,67 @@ mod reduce_sum_of_x2 { mod reduce_min_max_of_x { use crate::{F16, f16}; + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + fn reduce_min_max_of_x_v4_avx512fp16(this: &[f16]) -> (f32, f32) { + use core::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm512_cvtepi16_ph(_mm512_set1_epi16(f16::INFINITY.to_bits() as _)); + let mut max = _mm512_cvtepi16_ph(_mm512_set1_epi16(f16::NEG_INFINITY.to_bits() as _)); + while n >= 32 { + let x = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.cast())) }; + min = _mm512_min_ph(x, min); + max = _mm512_max_ph(x, max); + (n, a) = unsafe { (n - 32, a.add(32)) }; + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())) }; + min = _mm512_mask_min_ph(min, mask, x, min); + max = _mm512_mask_max_ph(max, mask, x, max); + } + let min = { + let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(min), 0)); + let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(min), 1)); + _mm512_reduce_min_ps(_mm512_min_ps(s_0, s_1)) + }; + let max = { + let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(max), 0)); + let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(max), 1)); + _mm512_reduce_max_ps(_mm512_max_ps(s_0, s_1)) + }; + (min, max) + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + #[cfg_attr(miri, ignore)] + fn reduce_min_max_of_x_v4_avx512fp16_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... 
skipped (v4:avx512fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let mut x = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + (x[0], x[1]) = (f16::NAN, -f16::NAN); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v4_avx512fp16(x) }; + let fallback = fallback(x); + assert_eq!(specialized.0, fallback.0); + assert_eq!(specialized.1, fallback.1); + } + } + } + #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v4")] @@ -613,8 +1402,122 @@ mod reduce_min_max_of_x { } } + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + #[target_feature(enable = "fp16")] + fn reduce_min_max_of_x_a2_fp16(this: &[f16]) -> (f32, f32) { + use crate::emulate::{emulate_vreinterpretq_f16_u16, partial_load}; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = emulate_vreinterpretq_f16_u16(vdupq_n_u16(0x7C00u16)); + let mut max = emulate_vreinterpretq_f16_u16(vdupq_n_u16(0xFC00u16)); + while n >= 8 { + let x = emulate_vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) }); + min = vminnmq_f16(x, min); + max = vmaxnmq_f16(x, max); + (n, a) = unsafe { (n - 8, a.add(8)) }; + } + if n > 0 { + let (_a,) = unsafe { partial_load!(8, n, a = f16::NAN) }; + (a,) = (_a.as_ptr(),); + let x = emulate_vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) }); + min = vminnmq_f16(x, min); + max = vmaxnmq_f16(x, max); + } + ( + vminnmvq_f32(vcvt_f32_f16(vminnm_f16( + vget_low_f16(min), + vget_high_f16(min), + ))), + vmaxnmvq_f32(vcvt_f32_f16(vmaxnm_f16( + vget_low_f16(max), + vget_high_f16(max), + ))), + ) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_min_max_of_x_a2_fp16_test() { + use rand::Rng; + if !crate::is_cpu_detected!("a2") || !crate::is_feature_detected!("fp16") { + println!("test {} ... 
skipped (a2:fp16)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let mut x = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + (x[0], x[1]) = (f16::NAN, -f16::NAN); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_a2_fp16(x) }; + let fallback = fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + fn reduce_min_max_of_x_a2(this: &[f16]) -> (f32, f32) { + use crate::emulate::{emulate_vreinterpret_f16_u16, partial_load}; + use core::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = vdupq_n_f32(f32::INFINITY); + let mut max = vdupq_n_f32(f32::NEG_INFINITY); + while n >= 4 { + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + min = vminnmq_f32(x, min); + max = vmaxnmq_f32(x, max); + (n, a) = unsafe { (n - 4, a.add(4)) }; + } + if n > 0 { + let (_a,) = unsafe { partial_load!(4, n, a = f16::NAN) }; + (a,) = (_a.as_ptr(),); + let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) }); + let x = vcvt_f32_f16(x); + min = vminnmq_f32(x, min); + max = vmaxnmq_f32(x, max); + } + (vminnmvq_f32(min), vmaxnmvq_f32(max)) + } + + #[cfg(all(target_arch = "aarch64", test))] + #[test] + fn reduce_min_max_of_x_a2_test() { + use rand::Rng; + if !crate::is_cpu_detected!("a2") { + println!("test {} ... skipped (a2)", module_path!()); + return; + } + let mut rng = rand::rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let mut x = (0..n) + .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0))) + .collect::>(); + (x[0], x[1]) = (f16::NAN, -f16::NAN); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_a2(x) }; + let fallback = fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + #[crate::multiversion( - @"v4", @"v3", "v2", "a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" + @"v4:avx512fp16", @"v4", @"v3", "v2", @"a2:fp16", @"a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1" )] pub fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) { let mut min = f32::INFINITY; @@ -638,15 +1541,42 @@ mod reduce_sum_of_xy { #[crate::target_cpu(enable = "v4")] #[target_feature(enable = "avx512fp16")] fn reduce_sum_of_xy_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { - unsafe extern "C" { - #[link_name = "fp16_reduce_sum_of_xy_v4_avx512fp16"] - unsafe fn f(n: usize, a: *const f16, b: *const f16) -> f32; - } assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let a = lhs.as_ptr(); - let b = rhs.as_ptr(); - unsafe { f(n, a, b) } + use core::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut _0 = _mm512_setzero_ph(); + let mut _1 = _mm512_setzero_ph(); + while n >= 64 { + let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + let y_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(b.add(0).cast())) }; + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(32).cast())) }; + let y_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(b.add(32).cast())) }; + _0 = _mm512_fmadd_ph(x_0, y_0, _0); + _1 = _mm512_fmadd_ph(x_1, y_1, _1); + (n, a, b) = unsafe { (n - 64, a.add(64), b.add(64)) }; + } + if n >= 32 { 
+ let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) }; + let y_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(b.add(0).cast())) }; + _0 = _mm512_fmadd_ph(x_0, y_0, _0); + (n, a, b) = unsafe { (n - 32, a.add(32), b.add(32)) }; + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())) }; + let y_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b.cast())) }; + _1 = _mm512_fmadd_ph(x_1, y_1, _1); + } + let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0)); + let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1)); + let s_2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0)); + let s_3 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1)); + _mm512_reduce_add_ps(_mm512_add_ps( + _mm512_add_ps(s_0, s_2), + _mm512_add_ps(s_1, s_3), + )) } #[cfg(all(target_arch = "x86_64", test))] @@ -849,15 +1779,111 @@ mod reduce_sum_of_xy { #[crate::target_cpu(enable = "a2")] #[target_feature(enable = "fp16")] fn reduce_sum_of_xy_a2_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { - unsafe extern "C" { - #[link_name = "fp16_reduce_sum_of_xy_a2_fp16"] - unsafe fn f(n: usize, a: *const f16, b: *const f16) -> f32; - } + use crate::emulate::partial_load; + use core::arch::aarch64::*; assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let a = lhs.as_ptr(); - let b = rhs.as_ptr(); - unsafe { f(n, a, b) } + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut sum_0 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_1 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_2 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_3 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_4 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_5 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_6 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + let mut sum_7 = vreinterpretq_f16_u16(vdupq_n_u16(0)); + while n >= 64 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(32).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(40).cast()) }); + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(48).cast()) }); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(56).cast()) }); + let y_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) }); + let y_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(8).cast()) }); + let y_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(16).cast()) }); + let y_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(24).cast()) }); + let y_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(32).cast()) }); + let y_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(40).cast()) }); + let y_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(48).cast()) }); + let y_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(56).cast()) }); + sum_0 = vfmaq_f16(sum_0, x_0, y_0); + sum_1 = vfmaq_f16(sum_1, x_1, y_1); + sum_2 = vfmaq_f16(sum_2, x_2, y_2); + sum_3 = vfmaq_f16(sum_3, x_3, y_3); + sum_4 = vfmaq_f16(sum_4, x_4, y_4); + sum_5 = vfmaq_f16(sum_5, x_5, y_5); + sum_6 = vfmaq_f16(sum_6, x_6, y_6); + sum_7 = 
vfmaq_f16(sum_7, x_7, y_7); + (n, a, b) = unsafe { (n - 64, a.add(64), b.add(64)) }; + } + if n >= 32 { + let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) }); + let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) }); + let y_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) }); + let y_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(8).cast()) }); + let y_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(16).cast()) }); + let y_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(24).cast()) }); + sum_0 = vfmaq_f16(sum_0, x_0, y_0); + sum_1 = vfmaq_f16(sum_1, x_1, y_1); + sum_2 = vfmaq_f16(sum_2, x_2, y_2); + sum_3 = vfmaq_f16(sum_3, x_3, y_3); + (n, a, b) = unsafe { (n - 32, a.add(32), b.add(32)) }; + } + if n >= 16 { + let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) }); + let y_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) }); + let y_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(8).cast()) }); + sum_4 = vfmaq_f16(sum_4, x_4, y_4); + sum_5 = vfmaq_f16(sum_5, x_5, y_5); + (n, a, b) = unsafe { (n - 16, a.add(16), b.add(16)) }; + } + if n >= 8 { + let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) }); + let y_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) }); + sum_6 = vfmaq_f16(sum_6, x_6, y_6); + (n, a, b) = unsafe { (n - 8, a.add(8), b.add(8)) }; + } + if n > 0 { + let (_a, _b) = unsafe { partial_load!(8, n, a, b) }; + (a, b) = (_a.as_ptr(), _b.as_ptr()); + let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) }); + let y_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.cast()) }); + sum_7 = vfmaq_f16(sum_7, x_7, y_7); + } + let s_0 = vcvt_f32_f16(vget_low_f16(sum_0)); + let s_1 = vcvt_f32_f16(vget_high_f16(sum_0)); + let s_2 = vcvt_f32_f16(vget_low_f16(sum_1)); + let s_3 = vcvt_f32_f16(vget_high_f16(sum_1)); + let s_4 = vcvt_f32_f16(vget_low_f16(sum_2)); + let s_5 = vcvt_f32_f16(vget_high_f16(sum_2)); + let s_6 = vcvt_f32_f16(vget_low_f16(sum_3)); + let s_7 = vcvt_f32_f16(vget_high_f16(sum_3)); + let s_8 = vcvt_f32_f16(vget_low_f16(sum_4)); + let s_9 = vcvt_f32_f16(vget_high_f16(sum_4)); + let s_a = vcvt_f32_f16(vget_low_f16(sum_5)); + let s_b = vcvt_f32_f16(vget_high_f16(sum_5)); + let s_c = vcvt_f32_f16(vget_low_f16(sum_6)); + let s_d = vcvt_f32_f16(vget_high_f16(sum_6)); + let s_e = vcvt_f32_f16(vget_low_f16(sum_7)); + let s_f = vcvt_f32_f16(vget_high_f16(sum_7)); + let s = vpaddq_f32( + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)), + vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7)), + ), + vpaddq_f32( + vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)), + vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)), + ), + ); + vaddvq_f32(s) } #[cfg(all(target_arch = "aarch64", test))] @@ -892,7 +1918,70 @@ mod reduce_sum_of_xy { } } - #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", #[cfg(target_endian = "little")] @"a3.512", @"a2:fp16", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1")] + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "a2")] + fn reduce_sum_of_xy_a2(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::emulate::{emulate_vreinterpret_f16_u16, partial_load}; + use core::arch::aarch64::*; + assert!(lhs.len() == rhs.len()); + let mut n = lhs.len(); + let mut a 
+    #[inline]
+    #[cfg(target_arch = "aarch64")]
+    #[crate::target_cpu(enable = "a2")]
+    fn reduce_sum_of_xy_a2(lhs: &[f16], rhs: &[f16]) -> f32 {
+        use crate::emulate::{emulate_vreinterpret_f16_u16, partial_load};
+        use core::arch::aarch64::*;
+        assert!(lhs.len() == rhs.len());
+        let mut n = lhs.len();
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut sum = vdupq_n_f32(0.0);
+        while n >= 4 {
+            let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) });
+            let x = vcvt_f32_f16(x);
+            let y = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(b.cast()) });
+            let y = vcvt_f32_f16(y);
+            sum = vfmaq_f32(sum, x, y);
+            (n, a, b) = unsafe { (n - 4, a.add(4), b.add(4)) };
+        }
+        if n > 0 {
+            let (_a, _b) = unsafe { partial_load!(4, n, a, b) };
+            (a, b) = (_a.as_ptr(), _b.as_ptr());
+            let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) });
+            let x = vcvt_f32_f16(x);
+            let y = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(b.cast()) });
+            let y = vcvt_f32_f16(y);
+            sum = vfmaq_f32(sum, x, y);
+        }
+        vaddvq_f32(sum)
+    }
+
+    #[cfg(all(target_arch = "aarch64", test))]
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn reduce_sum_of_xy_a2_test() {
+        use rand::Rng;
+        const EPSILON: f32 = 2.0;
+        if !crate::is_cpu_detected!("a2") {
+            println!("test {} ... skipped (a2)", module_path!());
+            return;
+        }
+        let mut rng = rand::rng();
+        for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } {
+            let n = 4016;
+            let lhs = (0..n)
+                .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0)))
+                .collect::<Vec<_>>();
+            let rhs = (0..n)
+                .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0)))
+                .collect::<Vec<_>>();
+            for z in 3984..4016 {
+                let lhs = &lhs[..z];
+                let rhs = &rhs[..z];
+                let specialized = unsafe { reduce_sum_of_xy_a2(lhs, rhs) };
+                let fallback = fallback(lhs, rhs);
+                assert!(
+                    (specialized - fallback).abs() < EPSILON,
+                    "specialized = {specialized}, fallback = {fallback}."
+                );
+            }
+        }
+    }
+
+    #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", #[cfg(target_endian = "little")] @"a3.512", @"a2:fp16", @"a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1")]
     pub fn reduce_sum_of_xy(lhs: &[f16], rhs: &[f16]) -> f32 {
         assert!(lhs.len() == rhs.len());
         let n = lhs.len();
@@ -914,15 +2003,46 @@ mod reduce_sum_of_d2 {
     #[crate::target_cpu(enable = "v4")]
     #[target_feature(enable = "avx512fp16")]
     fn reduce_sum_of_d2_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 {
-        unsafe extern "C" {
-            #[link_name = "fp16_reduce_sum_of_d2_v4_avx512fp16"]
-            unsafe fn f(n: usize, a: *const f16, b: *const f16) -> f32;
-        }
         assert!(lhs.len() == rhs.len());
-        let n = lhs.len();
-        let a = lhs.as_ptr();
-        let b = rhs.as_ptr();
-        unsafe { f(n, a, b) }
+        use core::arch::x86_64::*;
+        let mut n = lhs.len();
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut _0 = _mm512_setzero_ph();
+        let mut _1 = _mm512_setzero_ph();
+        while n >= 64 {
+            let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) };
+            let y_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(b.add(0).cast())) };
+            let x_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(32).cast())) };
+            let y_1 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(b.add(32).cast())) };
+            let d_0 = _mm512_sub_ph(x_0, y_0);
+            let d_1 = _mm512_sub_ph(x_1, y_1);
+            _0 = _mm512_fmadd_ph(d_0, d_0, _0);
+            _1 = _mm512_fmadd_ph(d_1, d_1, _1);
+            (n, a, b) = unsafe { (n - 64, a.add(64), b.add(64)) };
+        }
+        if n >= 32 {
+            let x_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(a.add(0).cast())) };
+            let y_0 = unsafe { _mm512_castsi512_ph(_mm512_loadu_epi16(b.add(0).cast())) };
+            let d_0 = _mm512_sub_ph(x_0, y_0);
+            _0 = _mm512_fmadd_ph(d_0, d_0, _0);
+            (n, a, b) = unsafe { (n - 32, a.add(32), b.add(32)) };
+        }
+        if n > 0 {
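+            // _bzhi_u32 keeps only the low `n` bits, so the masked load zeroes the
+            // out-of-range lanes; those lanes contribute x - y = 0 to the sum.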
+            let mask = _bzhi_u32(0xffffffff, n as u32);
+            let x_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())) };
+            let y_1 = unsafe { _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b.cast())) };
+            let d_1 = _mm512_sub_ph(x_1, y_1);
+            _1 = _mm512_fmadd_ph(d_1, d_1, _1);
+        }
+        let s_0 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 0));
+        let s_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_0), 1));
+        let s_2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 0));
+        let s_3 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(_mm512_castph_si512(_1), 1));
+        _mm512_reduce_add_ps(_mm512_add_ps(
+            _mm512_add_ps(s_0, s_2),
+            _mm512_add_ps(s_1, s_3),
+        ))
     }

     #[cfg(all(target_arch = "x86_64", test))]
@@ -1134,15 +2254,127 @@ mod reduce_sum_of_d2 {
     #[crate::target_cpu(enable = "a2")]
     #[target_feature(enable = "fp16")]
     fn reduce_sum_of_d2_a2_fp16(lhs: &[f16], rhs: &[f16]) -> f32 {
-        unsafe extern "C" {
-            #[link_name = "fp16_reduce_sum_of_d2_a2_fp16"]
-            unsafe fn f(n: usize, a: *const f16, b: *const f16) -> f32;
-        }
+        use crate::emulate::partial_load;
+        use core::arch::aarch64::*;
         assert!(lhs.len() == rhs.len());
-        let n = lhs.len();
-        let a = lhs.as_ptr().cast();
-        let b = rhs.as_ptr().cast();
-        unsafe { f(n, a, b) }
+        let mut n = lhs.len();
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut sum_0 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_1 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_2 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_3 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_4 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_5 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_6 = vreinterpretq_f16_u16(vdupq_n_u16(0));
+        let mut sum_7 = vreinterpretq_f16_u16(vdupq_n_u16(0));
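+        // Eight independent fp16 accumulators keep the FMA pipelines busy; the
+        // partial sums are widened to f32 only for the final horizontal reduction.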
+        while n >= 64 {
+            let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) });
+            let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) });
+            let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) });
+            let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) });
+            let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(32).cast()) });
+            let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(40).cast()) });
+            let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(48).cast()) });
+            let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(56).cast()) });
+            let y_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) });
+            let y_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(8).cast()) });
+            let y_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(16).cast()) });
+            let y_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(24).cast()) });
+            let y_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(32).cast()) });
+            let y_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(40).cast()) });
+            let y_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(48).cast()) });
+            let y_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(56).cast()) });
+            let d_0 = vsubq_f16(x_0, y_0);
+            let d_1 = vsubq_f16(x_1, y_1);
+            let d_2 = vsubq_f16(x_2, y_2);
+            let d_3 = vsubq_f16(x_3, y_3);
+            let d_4 = vsubq_f16(x_4, y_4);
+            let d_5 = vsubq_f16(x_5, y_5);
+            let d_6 = vsubq_f16(x_6, y_6);
+            let d_7 = vsubq_f16(x_7, y_7);
+            sum_0 = vfmaq_f16(sum_0, d_0, d_0);
+            sum_1 = vfmaq_f16(sum_1, d_1, d_1);
+            sum_2 = vfmaq_f16(sum_2, d_2, d_2);
+            sum_3 = vfmaq_f16(sum_3, d_3, d_3);
+            sum_4 = vfmaq_f16(sum_4, d_4, d_4);
+            sum_5 = vfmaq_f16(sum_5, d_5, d_5);
+            sum_6 = vfmaq_f16(sum_6, d_6, d_6);
+            sum_7 = vfmaq_f16(sum_7, d_7, d_7);
+            (n, a, b) = unsafe { (n - 64, a.add(64), b.add(64)) };
+        }
+        if n >= 32 {
+            let x_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) });
+            let x_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) });
+            let x_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(16).cast()) });
+            let x_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(24).cast()) });
+            let y_0 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) });
+            let y_1 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(8).cast()) });
+            let y_2 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(16).cast()) });
+            let y_3 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(24).cast()) });
+            let d_0 = vsubq_f16(x_0, y_0);
+            let d_1 = vsubq_f16(x_1, y_1);
+            let d_2 = vsubq_f16(x_2, y_2);
+            let d_3 = vsubq_f16(x_3, y_3);
+            sum_0 = vfmaq_f16(sum_0, d_0, d_0);
+            sum_1 = vfmaq_f16(sum_1, d_1, d_1);
+            sum_2 = vfmaq_f16(sum_2, d_2, d_2);
+            sum_3 = vfmaq_f16(sum_3, d_3, d_3);
+            (n, a, b) = unsafe { (n - 32, a.add(32), b.add(32)) };
+        }
+        if n >= 16 {
+            let x_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) });
+            let x_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(8).cast()) });
+            let y_4 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) });
+            let y_5 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(8).cast()) });
+            let d_4 = vsubq_f16(x_4, y_4);
+            let d_5 = vsubq_f16(x_5, y_5);
+            sum_4 = vfmaq_f16(sum_4, d_4, d_4);
+            sum_5 = vfmaq_f16(sum_5, d_5, d_5);
+            (n, a, b) = unsafe { (n - 16, a.add(16), b.add(16)) };
+        }
+        if n >= 8 {
+            let x_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.add(0).cast()) });
+            let y_6 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.add(0).cast()) });
+            let d_6 = vsubq_f16(x_6, y_6);
+            sum_6 = vfmaq_f16(sum_6, d_6, d_6);
+            (n, a, b) = unsafe { (n - 8, a.add(8), b.add(8)) };
+        }
+        if n > 0 {
+            let (_a, _b) = unsafe { partial_load!(8, n, a, b) };
+            (a, b) = (_a.as_ptr(), _b.as_ptr());
+            let x_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(a.cast()) });
+            let y_7 = vreinterpretq_f16_u16(unsafe { vld1q_u16(b.cast()) });
+            let d_7 = vsubq_f16(x_7, y_7);
+            sum_7 = vfmaq_f16(sum_7, d_7, d_7);
+        }
+        let s_0 = vcvt_f32_f16(vget_low_f16(sum_0));
+        let s_1 = vcvt_f32_f16(vget_high_f16(sum_0));
+        let s_2 = vcvt_f32_f16(vget_low_f16(sum_1));
+        let s_3 = vcvt_f32_f16(vget_high_f16(sum_1));
+        let s_4 = vcvt_f32_f16(vget_low_f16(sum_2));
+        let s_5 = vcvt_f32_f16(vget_high_f16(sum_2));
+        let s_6 = vcvt_f32_f16(vget_low_f16(sum_3));
+        let s_7 = vcvt_f32_f16(vget_high_f16(sum_3));
+        let s_8 = vcvt_f32_f16(vget_low_f16(sum_4));
+        let s_9 = vcvt_f32_f16(vget_high_f16(sum_4));
+        let s_a = vcvt_f32_f16(vget_low_f16(sum_5));
+        let s_b = vcvt_f32_f16(vget_high_f16(sum_5));
+        let s_c = vcvt_f32_f16(vget_low_f16(sum_6));
+        let s_d = vcvt_f32_f16(vget_high_f16(sum_6));
+        let s_e = vcvt_f32_f16(vget_low_f16(sum_7));
+        let s_f = vcvt_f32_f16(vget_high_f16(sum_7));
+        let s = vpaddq_f32(
+            vpaddq_f32(
+                vpaddq_f32(vpaddq_f32(s_0, s_1), vpaddq_f32(s_2, s_3)),
+                vpaddq_f32(vpaddq_f32(s_4, s_5), vpaddq_f32(s_6, s_7)),
+            ),
+            vpaddq_f32(
+                vpaddq_f32(vpaddq_f32(s_8, s_9), vpaddq_f32(s_a, s_b)),
+                vpaddq_f32(vpaddq_f32(s_c, s_d), vpaddq_f32(s_e, s_f)),
+            ),
+        );
+        vaddvq_f32(s)
     }

     #[cfg(all(target_arch = "aarch64", test))]
@@ -1177,7 +2409,72 @@ mod reduce_sum_of_d2 {
         }
     }

-    #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", #[cfg(target_endian = "little")] @"a3.512", @"a2:fp16", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1")]
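+    // Like reduce_sum_of_xy_a2 above, this plain `a2` variant widens f16 to f32
+    // before the subtract/square/accumulate, so it does not need the fp16 extension.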
+    #[inline]
+    #[cfg(target_arch = "aarch64")]
+    #[crate::target_cpu(enable = "a2")]
+    fn reduce_sum_of_d2_a2(lhs: &[f16], rhs: &[f16]) -> f32 {
+        use crate::emulate::{emulate_vreinterpret_f16_u16, partial_load};
+        use core::arch::aarch64::*;
+        assert!(lhs.len() == rhs.len());
+        let mut n = lhs.len();
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut sum = vdupq_n_f32(0.0);
+        while n >= 4 {
+            let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) });
+            let x = vcvt_f32_f16(x);
+            let y = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(b.cast()) });
+            let y = vcvt_f32_f16(y);
+            let d = vsubq_f32(x, y);
+            sum = vfmaq_f32(sum, d, d);
+            (n, a, b) = unsafe { (n - 4, a.add(4), b.add(4)) };
+        }
+        if n > 0 {
+            let (_a, _b) = unsafe { partial_load!(4, n, a, b) };
+            (a, b) = (_a.as_ptr(), _b.as_ptr());
+            let x = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(a.cast()) });
+            let x = vcvt_f32_f16(x);
+            let y = emulate_vreinterpret_f16_u16(unsafe { vld1_u16(b.cast()) });
+            let y = vcvt_f32_f16(y);
+            let d = vsubq_f32(x, y);
+            sum = vfmaq_f32(sum, d, d);
+        }
+        vaddvq_f32(sum)
+    }
+
+    #[cfg(all(target_arch = "aarch64", test))]
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn reduce_sum_of_d2_a2_test() {
+        use rand::Rng;
+        const EPSILON: f32 = 2.0;
+        if !crate::is_cpu_detected!("a2") {
+            println!("test {} ... skipped (a2)", module_path!());
+            return;
+        }
+        let mut rng = rand::rng();
+        for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } {
+            let n = 4016;
+            let lhs = (0..n)
+                .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0)))
+                .collect::<Vec<_>>();
+            let rhs = (0..n)
+                .map(|_| f16::_from_f32(rng.random_range(-1.0..=1.0)))
+                .collect::<Vec<_>>();
+            for z in 3984..4016 {
+                let lhs = &lhs[..z];
+                let rhs = &rhs[..z];
+                let specialized = unsafe { reduce_sum_of_d2_a2(lhs, rhs) };
+                let fallback = fallback(lhs, rhs);
+                assert!(
+                    (specialized - fallback).abs() < EPSILON,
+                    "specialized = {specialized}, fallback = {fallback}."
+                );
+            }
+        }
+    }
+
+    #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", #[cfg(target_endian = "little")] @"a3.512", @"a2:fp16", @"a2", "z17", "z16", "z15", "z14", "z13", "p9", "p8", "p7", "r1")]
     pub fn reduce_sum_of_d2(lhs: &[f16], rhs: &[f16]) -> f32 {
         assert!(lhs.len() == rhs.len());
         let n = lhs.len();