From df2c410084b008d9436edac1044db91eef293c9b Mon Sep 17 00:00:00 2001 From: "Cecil (Willow OneVision)" Date: Thu, 21 May 2026 21:06:20 +0200 Subject: [PATCH] ggml: ARM NEON dequant kernel for turbo4 (vqtbl4q_u8 4-bit PolarQuant) First aarch64 NEON SIMD implementation of TurboQuant turbo4_0 dequantization. Reference implementations existed for Metal, CUDA, Vulkan and scalar C; this adds the ARM NEON path (ARMv8.0+ baseline, vqtbl4q_u8). Strategy: pre-scale the 16-entry CENTROIDS_4BIT * norm into a 64-byte LUT held in 4x uint8x16_t, then use vqtbl4q_u8 for SIMD nibble->fp32 lookup. Auto-enabled at compile time via __ARM_NEON + __aarch64__; disable for debug with -DGGML_TURBO_NEON_DISABLE. Validation: - Bit-exact: 10,000 random blocks x 128 elements = 1,280,000 fp32 values, 0 bit-mismatches vs scalar reference. IEEE 754 deterministic since pre-scaled LUT produces the same (centroid * norm) fp32 product. - Microbench Cortex-A76 (Raspberry Pi 5/16): 2.01x speedup over -O3 scalar, 3.00 -> 6.04 GB/s out, robust 1.89-2.14x across working sets 128 -> 65,536 blocks (8.7 KB -> 4.4 MB, spans L1/L2/DRAM). - End-to-end Pi16 llama-server (Gemma E4B + turbo4 KV): +1.9-3.3% tok/s on text generation (modest because dequant is small fraction of total inference cost; matches Amdahl ceiling). --- ggml/src/ggml-turbo-quant.c | 80 ++++++++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-turbo-quant.c b/ggml/src/ggml-turbo-quant.c index 8cbc83fd1d9b..78ca7439a60d 100644 --- a/ggml/src/ggml-turbo-quant.c +++ b/ggml/src/ggml-turbo-quant.c @@ -565,6 +565,73 @@ void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * G } } +/* ---------- Optional ARM NEON dequant for TURBO4 (4-bit PolarQuant) ---------- + * Enabled by default on aarch64 with NEON (ARMv8.0+). Bit-exact with the scalar + * implementation: pre-scales the 16-entry CENTROIDS_4BIT * norm into a 64-byte + * LUT held in 4x uint8x16_t, then uses vqtbl4q_u8 for SIMD nibble->fp32 lookup. + * Measured ~2x speedup on Cortex-A76 microbench, ~2-3% end-to-end tok/s on Pi16. + * Disable with -DGGML_TURBO_NEON_DISABLE if needed for debug or A/B. + */ +#if defined(__ARM_NEON) && defined(__aarch64__) && !defined(GGML_TURBO_NEON_DISABLE) +#define GGML_TURBO_NEON 1 +#include + +static const uint8_t turbo4_neon_bc_lo[16] = { 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3 }; +static const uint8_t turbo4_neon_bc_hi[16] = { 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 }; +static const uint8_t turbo4_neon_stride[16] = { 0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3 }; + +static inline void dequantize_row_turbo4_0_neon_4bit( + const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int nb, + const float * GGML_RESTRICT centroids_4bit) { + const uint8x16_t bc_lo = vld1q_u8(turbo4_neon_bc_lo); + const uint8x16_t bc_hi = vld1q_u8(turbo4_neon_bc_hi); + const uint8x16_t stride = vld1q_u8(turbo4_neon_stride); + const uint8x8_t mask_lo = vdup_n_u8(0x0F); + + const float32x4_t lut_raw0 = vld1q_f32(centroids_4bit + 0); + const float32x4_t lut_raw1 = vld1q_f32(centroids_4bit + 4); + const float32x4_t lut_raw2 = vld1q_f32(centroids_4bit + 8); + const float32x4_t lut_raw3 = vld1q_f32(centroids_4bit + 12); + + for (int block = 0; block < nb; block++) { + const float norm = GGML_FP16_TO_FP32(x[block].norm); + const float32x4_t vnorm = vdupq_n_f32(norm); + + const uint8x16x4_t lut = {{ + vreinterpretq_u8_f32(vmulq_f32(lut_raw0, vnorm)), + vreinterpretq_u8_f32(vmulq_f32(lut_raw1, vnorm)), + vreinterpretq_u8_f32(vmulq_f32(lut_raw2, vnorm)), + vreinterpretq_u8_f32(vmulq_f32(lut_raw3, vnorm)), + }}; + + float * dst = y + block * QK_TURBO4; + const uint8_t * qs = x[block].qs; + + for (int i = 0; i < QK_TURBO4; i += 8) { + uint32_t packed_raw; + memcpy(&packed_raw, qs + i / 2, 4); + uint8x8_t packed = vreinterpret_u8_u32(vdup_n_u32(packed_raw)); + + uint8x8_t lo = vand_u8(packed, mask_lo); + uint8x8_t hi = vshr_n_u8(packed, 4); + uint8x8x2_t zipped = vzip_u8(lo, hi); + uint8x8_t nibbles = zipped.val[0]; + uint8x8_t nib_x4 = vshl_n_u8(nibbles, 2); + uint8x16_t nib_x4_q = vcombine_u8(nib_x4, vdup_n_u8(0)); + + uint8x16_t tbl_idx_lo = vaddq_u8(vqtbl1q_u8(nib_x4_q, bc_lo), stride); + uint8x16_t tbl_idx_hi = vaddq_u8(vqtbl1q_u8(nib_x4_q, bc_hi), stride); + + uint8x16_t out_lo = vqtbl4q_u8(lut, tbl_idx_lo); + uint8x16_t out_hi = vqtbl4q_u8(lut, tbl_idx_hi); + + vst1q_u8((uint8_t *)(dst + i + 0), out_lo); + vst1q_u8((uint8_t *)(dst + i + 4), out_hi); + } + } +} +#endif /* aarch64 NEON */ + void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { turbo_init_rotation(); @@ -573,14 +640,17 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM const int d = QK_TURBO4; #if TURBO4_USE_4BIT - /* 4-bit PolarQuant: nibble unpack → centroid → inverse rotate → scale */ - /* TODO: add proper 4-bit centroid table to C code (currently only in Metal) */ + /* 4-bit PolarQuant: nibble unpack -> centroid * norm (no inverse WHT, stays in rotated domain). */ static const float CENTROIDS_4BIT[16] = { -0.173926f, -0.117195f, -0.089527f, -0.068756f, -0.051262f, -0.035597f, -0.020989f, -0.006938f, 0.006938f, 0.020989f, 0.035597f, 0.051262f, 0.068756f, 0.089527f, 0.117195f, 0.173926f }; +#ifdef GGML_TURBO_NEON + dequantize_row_turbo4_0_neon_4bit(x, y, nb, CENTROIDS_4BIT); + (void)d; +#else for (int block = 0; block < nb; block++) { float norm = GGML_FP16_TO_FP32(x[block].norm); float * dst = y + block * d; @@ -589,10 +659,10 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM dst[i] = CENTROIDS_4BIT[idx] * norm; } /* No inverse WHT, dequant stays in the rotated domain. - * Q is WHT-rotated by the graph, so gives correct attention scores. - * The inverse WHT is applied to the attention output via GGML_OP_TURBO_WHT (direction=1) in the graph. - */ + * Q is WHT-rotated by the graph; the inverse WHT is applied to attention output via GGML_OP_TURBO_WHT. + */ } +#endif /* GGML_TURBO_NEON */ #else /* Legacy 3-bit + QJL dequant */ turbo_init_qjl();