From 2be97efd80c9cb40362938aad36c374cbeae20d9 Mon Sep 17 00:00:00 2001 From: Xie Gengxin Date: Thu, 20 Feb 2025 16:44:08 +0800 Subject: [PATCH] feat: better memcmp implementation for avx2 --- benchmark/sonic.hpp | 3 +- include/sonic/dom/dynamicnode.h | 83 +++++++- include/sonic/internal/arch/avx2/base.h | 199 ++++++++++++++++++ .../internal/arch/common/arm_common/base.h | 9 + include/sonic/internal/arch/neon/base.h | 2 + include/sonic/internal/arch/simd_base.h | 2 + include/sonic/internal/arch/sse/base.h | 9 + include/sonic/internal/arch/sve2-128/base.h | 2 + tests/memcmp_test.cpp | 162 ++++++++++++++ 9 files changed, 465 insertions(+), 6 deletions(-) create mode 100644 tests/memcmp_test.cpp diff --git a/benchmark/sonic.hpp b/benchmark/sonic.hpp index 065f1ae..fb36e85 100644 --- a/benchmark/sonic.hpp +++ b/benchmark/sonic.hpp @@ -70,7 +70,8 @@ class SonicParseResult : public ParseResult, switch (v.GetType()) { case sonic_json::kObject: for (auto m = v.MemberBegin(); m != v.MemberEnd(); ++m) { - auto re = v.FindMember(m->name.GetStringView()); + auto name_sv = m->name.GetStringView(); + auto re = v.FindMember(name_sv.data(), name_sv.size()); if (re != v.MemberEnd()) { stat.members++; find_value(re->value, stat); diff --git a/include/sonic/dom/dynamicnode.h b/include/sonic/dom/dynamicnode.h index ad1c61a..a5918e8 100644 --- a/include/sonic/dom/dynamicnode.h +++ b/include/sonic/dom/dynamicnode.h @@ -232,6 +232,39 @@ class DNode : public GenericNode> { using BaseNode::operator[]; using BaseNode::FindMember; + + /** + * @brief Find a specific member in an object. A member is a pair of a name + * node and a value node. + * @param key target name pointer + * @param len target name length + * @retval MemberEnd() not found + * @retval others iterator for found member + * @note If the target name is a string literal, string_view can be optimized + * by the compiler. This function provides a better memcmp implementation + * than std::memcmp when the length is not too large. 
+ */ + sonic_force_inline MemberIterator FindMember(const char* key, + size_t len) noexcept { + return findMemberImpl(key, len); + } + + /** + * @brief Find a specific member in an object. A member is a pair of a name + * node and a value node. + * @param key target name pointer + * @param len target name length + * @retval MemberEnd() not found + * @retval others iterator for found member + * @note If the target name is a string literal, string_view can be optimized + * by the compiler. This function provides a better memcmp implementation + * than std::memcmp when the length is not too large. + */ + sonic_force_inline ConstMemberIterator FindMember(const char* key, + size_t len) const noexcept { + return findMemberImpl(key, len); + } + using BaseNode::HasMember; /** @@ -301,7 +334,19 @@ class DNode : public GenericNode> { private: using MSType = StringView; using MAType = MapAllocator, Allocator>; +#if defined(SONIC_STATIC_DISPATCH) + struct Less { + bool operator()(MSType s1, MSType s2) const { + size_t n1 = s1.size(), n2 = s2.size(); + const size_t len = std::min(n1, n2); + int cmp = internal::InlinedMemcmp(s1.data(), s2.data(), len); + return cmp < 0 || (cmp == 0 && n1 < n2); + } + }; + using map_type = std::multimap; +#else using map_type = std::multimap, MAType>; +#endif struct MetaNode { size_t cap; @@ -565,13 +610,17 @@ class DNode : public GenericNode> { return ((MetaNode*)(this->o.next.children))->map; } + sonic_force_inline MemberIterator findFromMap(StringView key) const { + auto it = getMap()->find(MSType(key.data(), key.size())); + if (it != getMap()->end()) { + return memberBeginUnsafe() + it->second; + } + return memberEndUnsafe(); + } + sonic_force_inline MemberIterator findMemberImpl(StringView key) const { if (nullptr != getMap()) { - auto it = getMap()->find(MSType(key.data(), key.size())); - if (it != getMap()->end()) { - return memberBeginUnsafe() + it->second; - } - return memberEndUnsafe(); + return findFromMap(key); } auto it = this->MemberBegin(); for (auto 
e = this->MemberEnd(); it != e; ++it) { @@ -582,6 +631,30 @@ class DNode : public GenericNode> { return const_cast(it); } + sonic_force_inline MemberIterator findMemberImpl(const char* key, + size_t len) const { + /************************************************** + * Only calling internal memcmp when static dispatch. + * Dynamic dispatch will have indirect call. + **************************************************/ +#if defined(SONIC_STATIC_DISPATCH) + if (nullptr != getMap()) { + return findFromMap(StringView(key, len)); + } + auto it = this->MemberBegin(); + for (auto e = this->MemberEnd(); it != e; ++it) { + auto name_sv = it->name.GetStringView(); + if (name_sv.size() == len && + internal::InlinedMemcmpEq(name_sv.data(), key, len)) { + break; + } + } + return const_cast(it); +#else + return findMemberImpl(StringView(key, len)); +#endif + } + sonic_force_inline DNode& findValueImpl(StringView key) const noexcept { auto m = findMemberImpl(key); if (m != this->MemberEnd()) { diff --git a/include/sonic/internal/arch/avx2/base.h b/include/sonic/internal/arch/avx2/base.h index 620ce70..05a2125 100644 --- a/include/sonic/internal/arch/avx2/base.h +++ b/include/sonic/internal/arch/avx2/base.h @@ -19,8 +19,32 @@ #include +#include + #include "simd.h" +#ifdef __GNUC__ +#if defined(__SANITIZE_THREAD__) || defined(__SANITIZE_ADDRESS__) || \ + defined(__SANITIZE_LEAK__) || defined(__SANITIZE_UNDEFINED__) +#ifndef SONIC_USE_SANITIZE +#define SONIC_USE_SANITIZE +#endif +#endif +#endif + +#if defined(__clang__) +#if defined(__has_feature) +#if __has_feature(address_sanitizer) || __has_feature(thread_sanitizer) || \ + __has_feature(memory_sanitizer) || \ + __has_feature(undefined_behavior_sanitizer) || \ + __has_feature(leak_sanitizer) +#ifndef SONIC_USE_SANITIZE +#define SONIC_USE_SANITIZE +#endif +#endif +#endif +#endif + SONIC_PUSH_HASWELL namespace sonic_json { @@ -157,6 +181,181 @@ sonic_force_inline void Xmemcpy<16>(void* dst_, const void* src_, } } +namespace { +static 
sonic_force_inline bool in_page_32(const void* a, const void* b) { +#ifdef SONIC_USE_SANITIZE + (void)a; + (void)b; + return false; +#else + static constexpr size_t VecLen = 32; + static constexpr size_t PageSize = 4096; + size_t addr = (size_t)(a) | (size_t)(b); + return ((addr) & (PageSize - 1)) <= (PageSize - VecLen); +#endif +} + +static sonic_force_inline int cmp_lt_32(const void* _l, const void* _r, + size_t s) { + auto lhs = static_cast(_l); + auto rhs = static_cast(_r); + if (in_page_32(lhs, rhs)) { + __m256i vec_l = _mm256_loadu_si256((__m256i const*)rhs); + __m256i vec_r = _mm256_loadu_si256((__m256i const*)lhs); + __m256i ans = _mm256_cmpeq_epi8(vec_l, vec_r); + int mask = _mm256_movemask_epi8(ans) + 1; + // mask = mask << (32 -s); + __asm__("bzhil %1, %2, %[result]\n\t" + : [result] "=r"(mask) + : "r"((int)s), "r"(mask)); + if (mask) { + int ne_idx = __builtin_ctz(mask); + // if (lhs[ne_idx] < rhs[ne_idx]) return -1; + // else return 1; + return lhs[ne_idx] - rhs[ne_idx]; + } else { + return 0; + } + } + return std::memcmp(lhs, rhs, s); +} + +// slow path +static inline bool is_eq_lt_32_cross_page(const void* _a, const void* _b, + unsigned int s) { + auto a = static_cast(_a); + auto b = static_cast(_b); + if (s >= 16) { + __m128i vec_a = _mm_loadu_si128((__m128i const*)a); + __m128i vec_b = _mm_loadu_si128((__m128i const*)b); + __m128i ans1 = _mm_cmpeq_epi8(vec_a, vec_b); + + vec_a = _mm_loadu_si128((__m128i const*)(a + s - 16)); + vec_b = _mm_loadu_si128((__m128i const*)(b + s - 16)); + __m128i ans2 = _mm_cmpeq_epi8(vec_a, vec_b); + + __m128i ans = _mm_and_si128(ans1, ans2); + int mask = _mm_movemask_epi8(ans); + return mask == 0xFFFF; + } + // cross page + if (s >= 8) { + return __builtin_memcmp(a, b, 8) == 0 && + __builtin_memcmp(a + s - 8, b + s - 8, 8) == 0; + } else if (s >= 4) { + return __builtin_memcmp(a, b, 4) == 0 && + __builtin_memcmp(a + s - 4, b + s - 4, 4) == 0; + } else if (s >= 2) { + return __builtin_memcmp(a, b, 2) == 0 && + 
__builtin_memcmp(a + s - 2, b + s - 2, 2) == 0; + } else { + return *a == *b; + } + return true; +} + +static sonic_force_inline bool is_eq_lt_32(const void* _a, const void* _b, + size_t s) { + auto a = static_cast(_a); + auto b = static_cast(_b); + if (in_page_32(a, b)) { + __m256i vec_a = _mm256_loadu_si256((__m256i const*)a); + __m256i vec_b = _mm256_loadu_si256((__m256i const*)b); + __m256i ans = _mm256_cmpeq_epi8(vec_a, vec_b); + int mask = _mm256_movemask_epi8(ans) + 1; + // mask = mask << (32 -s); + __asm__("bzhil %1, %2, %[result]\n\t" + : [result] "=r"(mask) + : "r"((int)s), "r"(mask)); + return mask == 0; + } + return is_eq_lt_32_cross_page(a, b, s); +} +} // namespace + +sonic_force_inline bool InlinedMemcmpEq(const void* _a, const void* _b, + size_t s) { + auto a = static_cast(_a); + auto b = static_cast(_b); + if (s == 0) return true; + if (s < 32) return is_eq_lt_32(a, b, s); + size_t avx2_end = (s & (~31ULL)); + + __m256i vec_a = _mm256_loadu_si256((__m256i const*)(a)); + __m256i vec_b = _mm256_loadu_si256((__m256i const*)(b)); + __m256i ans_1 = _mm256_cmpeq_epi8(vec_a, vec_b); + // unsigned int mask = _mm256_movemask_epi8(ans_1) + 1; + // if (mask) return false; + + for (size_t i = 32; i < avx2_end; i += 32) { + vec_a = _mm256_loadu_si256((__m256i const*)(a + i)); + vec_b = _mm256_loadu_si256((__m256i const*)(b + i)); + __m256i ans = _mm256_cmpeq_epi8(vec_a, vec_b); + unsigned int mask = _mm256_movemask_epi8(ans) + 1; + if (mask) return false; + } + // no branch for s = x32 + // if (avx2_end == s) return true; + // s >= 32 overlap + { + vec_a = _mm256_loadu_si256((__m256i const*)(a + s - 32)); + vec_b = _mm256_loadu_si256((__m256i const*)(b + s - 32)); + __m256i ans = _mm256_cmpeq_epi8(vec_a, vec_b); + ans = _mm256_and_si256(ans, ans_1); + unsigned int mask = _mm256_movemask_epi8(ans) + 1; + if (mask) return false; + } + return true; +} + +sonic_force_inline int InlinedMemcmp(const void* _l, const void* _r, size_t s) { + auto lhs = static_cast(_l); 
+ auto rhs = static_cast(_r); + if (s == 0) return 0; + if (s < 32) return cmp_lt_32(lhs, rhs, s); + size_t avx2_end = (s & (~31ULL)); + + __m256i vec_l = _mm256_loadu_si256((__m256i const*)(lhs)); + __m256i vec_r = _mm256_loadu_si256((__m256i const*)(rhs)); + __m256i ans_1 = _mm256_cmpeq_epi8(vec_l, vec_r); + uint32_t mask = static_cast(_mm256_movemask_epi8(ans_1)) + 1; + if (mask) { + int ne_idx = __builtin_ctz(mask); + // if (lhs[ne_idx] < rhs[ne_idx]) return -1; + // else return 1; + return lhs[ne_idx] - rhs[ne_idx]; + } + + for (size_t i = 32; i < avx2_end; i += 32) { + vec_l = _mm256_loadu_si256((__m256i const*)(lhs + i)); + vec_r = _mm256_loadu_si256((__m256i const*)(rhs + i)); + __m256i ans = _mm256_cmpeq_epi8(vec_l, vec_r); + mask = static_cast(_mm256_movemask_epi8(ans)) + 1; + if (mask) { + int ne_idx = __builtin_ctz(mask); + // if (lhs[i + ne_idx] < rhs[i + ne_idx]) return -1; + // else return 1; + return lhs[i + ne_idx] - rhs[i + ne_idx]; + } + } + // no branch for s = x32 + // if (avx2_end == s) return true; + // s >= 32 overlap + { + size_t offset = s - 32; + vec_l = _mm256_loadu_si256((__m256i const*)(lhs + offset)); + vec_r = _mm256_loadu_si256((__m256i const*)(rhs + offset)); + __m256i ans = _mm256_cmpeq_epi8(vec_l, vec_r); + // ans = _mm256_and_si256(ans, ans_1); + unsigned int mask = static_cast(_mm256_movemask_epi8(ans)) + 1; + if (mask) { + int ne_idx = __builtin_ctz(mask); + return lhs[offset + ne_idx] - rhs[offset + ne_idx]; + } + } + return 0; +} + } // namespace avx2 } // namespace internal } // namespace sonic_json diff --git a/include/sonic/internal/arch/common/arm_common/base.h b/include/sonic/internal/arch/common/arm_common/base.h index ea2f432..f07e05f 100644 --- a/include/sonic/internal/arch/common/arm_common/base.h +++ b/include/sonic/internal/arch/common/arm_common/base.h @@ -71,6 +71,15 @@ sonic_force_inline void Xmemcpy(void* dst_, const void* src_, size_t chunks) { std::memcpy(dst_, src_, chunks * ChunkSize); } 
+sonic_force_inline bool InlinedMemcmpEq(const void* _a, const void* _b, + size_t s) { + return std::memcmp(_a, _b, s) == 0; +} + +sonic_force_inline int InlinedMemcmp(const void* _l, const void* _r, size_t s) { + return std::memcmp(_l, _r, s); +} + } // namespace arm_common } // namespace internal } // namespace sonic_json diff --git a/include/sonic/internal/arch/neon/base.h b/include/sonic/internal/arch/neon/base.h index 954fe59..eb33723 100644 --- a/include/sonic/internal/arch/neon/base.h +++ b/include/sonic/internal/arch/neon/base.h @@ -25,6 +25,8 @@ namespace neon { using sonic_json::internal::arm_common::ClearLowestBit; using sonic_json::internal::arm_common::CountOnes; +using sonic_json::internal::arm_common::InlinedMemcmp; +using sonic_json::internal::arm_common::InlinedMemcmpEq; using sonic_json::internal::arm_common::LeadingZeroes; using sonic_json::internal::arm_common::PrefixXor; using sonic_json::internal::arm_common::TrailingZeroes; diff --git a/include/sonic/internal/arch/simd_base.h b/include/sonic/internal/arch/simd_base.h index 4cdf3f0..760a13c 100644 --- a/include/sonic/internal/arch/simd_base.h +++ b/include/sonic/internal/arch/simd_base.h @@ -29,6 +29,8 @@ SONIC_USING_ARCH_FUNC(LeadingZeroes); SONIC_USING_ARCH_FUNC(CountOnes); SONIC_USING_ARCH_FUNC(PrefixXor); SONIC_USING_ARCH_FUNC(Xmemcpy); +SONIC_USING_ARCH_FUNC(InlinedMemcmpEq); +SONIC_USING_ARCH_FUNC(InlinedMemcmp); } // namespace internal } // namespace sonic_json diff --git a/include/sonic/internal/arch/sse/base.h b/include/sonic/internal/arch/sse/base.h index 106f4bd..725af3c 100644 --- a/include/sonic/internal/arch/sse/base.h +++ b/include/sonic/internal/arch/sse/base.h @@ -94,6 +94,15 @@ sonic_force_inline void Xmemcpy<32>(void* dst_, const void* src_, Xmemcpy<16>(dst_, src_, chunks * 2); } +sonic_force_inline bool InlinedMemcmpEq(const void* _a, const void* _b, + size_t s) { + return std::memcmp(_a, _b, s) == 0; +} + +sonic_force_inline int InlinedMemcmp(const void* _l, const void* 
_r, size_t s) { + return std::memcmp(_l, _r, s); +} + } // namespace sse } // namespace internal } // namespace sonic_json diff --git a/include/sonic/internal/arch/sve2-128/base.h b/include/sonic/internal/arch/sve2-128/base.h index f3164f4..2b15bcc 100644 --- a/include/sonic/internal/arch/sve2-128/base.h +++ b/include/sonic/internal/arch/sve2-128/base.h @@ -25,6 +25,8 @@ namespace sve2_128 { using sonic_json::internal::arm_common::ClearLowestBit; using sonic_json::internal::arm_common::CountOnes; +using sonic_json::internal::arm_common::InlinedMemcmp; +using sonic_json::internal::arm_common::InlinedMemcmpEq; using sonic_json::internal::arm_common::LeadingZeroes; using sonic_json::internal::arm_common::PrefixXor; using sonic_json::internal::arm_common::TrailingZeroes; diff --git a/tests/memcmp_test.cpp b/tests/memcmp_test.cpp new file mode 100644 index 0000000..b81d26d --- /dev/null +++ b/tests/memcmp_test.cpp @@ -0,0 +1,162 @@ +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "include/sonic/internal/arch/avx2/base.h" +#include "include/sonic/internal/arch/sonic_cpu_feature.h" + +#if defined(SONIC_HAVE_AVX2) && !defined(SONIC_DYNAMIC_DISPATCH) +namespace { + +using namespace sonic_json::internal::avx2; + +static std::string random_string(int str_len) { + // const char * strs = + // "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&*()这是一个字符串"; + std::string re; + + std::random_device rd; + std::mt19937 gen(rd()); + for (int i = 0; i < str_len; ++i) { + char c = gen() % 26 + 'a'; + re.append(1, c); + } + return re; +} + +bool is_correct(int a, int b) { + if (a < 0) return b < 0; + if (a > 0) return b > 0; + return a == b; +} + +TEST(InlinedMemcmp, Basic) { + EXPECT_EQ(0, InlinedMemcmp("", "", 0)); + EXPECT_EQ(0, InlinedMemcmp("123", "1", 0)); + EXPECT_EQ(0, InlinedMemcmp("123", "1", 1)); + EXPECT_EQ(-1, InlinedMemcmp("12345678901234567890123456789012345", + "22345678901234567890123456789012345", 35)); + for (int i = 0; i < 1024; ++i) { + std::string str1 = random_string(i); + std::string str2 = random_string(i); + EXPECT_EQ(str1.size(), str2.size()); + EXPECT_TRUE( + is_correct(std::memcmp(str1.data(), str2.data(), str1.size()), + InlinedMemcmp(str1.data(), str2.data(), str1.size()))) + << "str1 is: " << str1 << std::endl + << "str2 is: " << str2 << std::endl + << "std::memcmp is: " + << std::memcmp(str1.data(), str2.data(), str1.size()) << std::endl + << "InlinedMemcmp is: " + << InlinedMemcmp(str1.data(), str2.data(), str1.size()) << std::endl; + } + + for (int i = 1; i <= 1024; ++i) { + std::string str = random_string(i); + for (int j = 0; j < i; ++j) { + std::string str1 = str; + std::string str2 = str; + EXPECT_EQ(0, InlinedMemcmp(str1.data(), str2.data(), str1.size())); + str1[j] = '1'; + str2[j] = '2'; + EXPECT_TRUE(InlinedMemcmp(str1.data(), str2.data(), str1.size()) < 0); + str1[j] = '2'; + str2[j] = '1'; + 
EXPECT_TRUE(InlinedMemcmp(str1.data(), str2.data(), str1.size()) > 0); + } + } +} + +TEST(InlinedMemcmp, CrossPage) { + for (int i = 1; i <= 1024; ++i) { + std::string str = random_string(i); + auto a_ptr = std::unique_ptr( + static_cast(aligned_alloc(4096, 4096 * 2)), + [](char* ptr) { free(ptr); }); + auto b_ptr = std::unique_ptr(new char[i]); + char* a = a_ptr.get() + 4095; + char* b = b_ptr.get(); + for (int j = 0; j < i; ++j) { + std::memcpy(a, str.data(), i); + std::memcpy(b, str.data(), i); + EXPECT_EQ(0, InlinedMemcmp(a, b, i)); + a[j] = '1'; + b[j] = '2'; + EXPECT_TRUE(InlinedMemcmp(a, b, i) < 0); + a[j] = '2'; + b[j] = '1'; + EXPECT_TRUE(InlinedMemcmp(a, b, i) > 0); + } + } +} + +void success_helper(const void* a, const void* b, size_t s) { + EXPECT_TRUE(InlinedMemcmpEq(a, b, s)) + << "a is: " << std::string((char*)a, s) << std::endl + << "b is: " << std::string((char*)b, s) << std::endl; +} + +void failed_helper(const void* a, const void* b, size_t s) { + EXPECT_FALSE(InlinedMemcmpEq(a, b, s)) + << "a is: " << std::string((char*)a, s) << std::endl + << "b is: " << std::string((char*)b, s) << std::endl; +} + +TEST(InlinedMemcmpEq, Basic) { + { + std::string str = random_string(1024); + for (size_t i = 1; i < 1024; ++i) { + auto a = std::unique_ptr(new char[i]); + auto b = std::unique_ptr(new char[i]); + std::memcpy(a.get(), str.data(), i); + std::memcpy(b.get(), str.data(), i); + success_helper(a.get(), b.get(), i); + for (size_t j = i - 1; j > 0; --j) { + a[j] = 'x'; + b[j] = 'y'; + failed_helper(a.get(), b.get(), j + 1); + } + } + } +} + +TEST(InlinedMemcmpEq, CrossPage) { + { + std::string str = random_string(1024); + for (size_t i = 1; i < 1024; ++i) { + auto a_ptr = std::unique_ptr( + static_cast(aligned_alloc(4096, 4096 * 2)), + [](char* ptr) { free(ptr); }); + auto b = std::unique_ptr(new char[i]); + char* a = a_ptr.get() + 4095; + std::memcpy(a, str.data(), i); + std::memcpy(b.get(), str.data(), i); + success_helper(a, b.get(), i); + for (size_t j 
= i - 1; j > 0; --j) { + a[j] = 'x'; + b[j] = 'y'; + failed_helper(a, b.get(), j + 1); + } + } + } +} + +} // namespace +#endif