Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 75 additions & 5 deletions ggml/src/ggml-turbo-quant.c
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,73 @@ void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * G
}
}

/* ---------- Optional ARM NEON dequant for TURBO4 (4-bit PolarQuant) ----------
* Enabled by default on aarch64 with NEON (ARMv8.0+). Bit-exact with the scalar
* implementation: pre-scales the 16-entry CENTROIDS_4BIT * norm into a 64-byte
* LUT held in 4x uint8x16_t, then uses vqtbl4q_u8 for SIMD nibble->fp32 lookup.
* Measured ~2x speedup on Cortex-A76 microbench, ~2-3% end-to-end tok/s on Pi16.
* Disable with -DGGML_TURBO_NEON_DISABLE if needed for debug or A/B.
*/
#if defined(__ARM_NEON) && defined(__aarch64__) && !defined(GGML_TURBO_NEON_DISABLE)
#define GGML_TURBO_NEON 1
#include <arm_neon.h>

/* Byte-index tables used by dequantize_row_turbo4_0_neon_4bit().
 *
 * turbo4_neon_bc_lo / turbo4_neon_bc_hi are vqtbl1q_u8 selectors: applied to a
 * vector whose first 8 bytes hold per-element LUT byte offsets (nibble * 4),
 * they repeat each of bytes 0..3 (lo) or 4..7 (hi) four times, giving one
 * 4-byte group per output float.
 *
 * turbo4_neon_stride is then added to those groups so the four lanes of each
 * group address bytes 0,1,2,3 of the selected float inside the 64-byte LUT.
 */
static const uint8_t turbo4_neon_bc_lo[16] = { 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3 };
static const uint8_t turbo4_neon_bc_hi[16] = { 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 };
static const uint8_t turbo4_neon_stride[16] = { 0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3 };

/**
 * NEON dequantization of TURBO4 blocks (4-bit path).
 *
 * For each block the 16 centroids are multiplied by the block norm and the
 * resulting 16 floats (64 bytes) are kept as a byte lookup table. Every 4-bit
 * index is turned into the four byte offsets of its float and fetched with
 * vqtbl4q_u8, so eight output floats are produced per inner iteration.
 *
 * Output element 2*j comes from the low nibble of qs[j], element 2*j+1 from
 * the high nibble.
 *
 * @param x              input blocks; nb of them are read
 * @param y              output buffer; QK_TURBO4 floats are written per block
 * @param nb             number of blocks to process
 * @param centroids_4bit unscaled centroid table; 16 floats are read
 *
 * NOTE(review): the inner loop writes 8 floats per step, so this assumes
 * QK_TURBO4 is a multiple of 8 -- confirm against the block definition.
 */
static inline void dequantize_row_turbo4_0_neon_4bit(
        const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int nb,
        const float * GGML_RESTRICT centroids_4bit) {
    const uint8x16_t bc_lo   = vld1q_u8(turbo4_neon_bc_lo);
    const uint8x16_t bc_hi   = vld1q_u8(turbo4_neon_bc_hi);
    const uint8x16_t stride  = vld1q_u8(turbo4_neon_stride);
    const uint8x8_t  mask_lo = vdup_n_u8(0x0F);

    /* Unscaled centroids, loaded once; scaled per block below. */
    const float32x4_t lut_raw0 = vld1q_f32(centroids_4bit + 0);
    const float32x4_t lut_raw1 = vld1q_f32(centroids_4bit + 4);
    const float32x4_t lut_raw2 = vld1q_f32(centroids_4bit + 8);
    const float32x4_t lut_raw3 = vld1q_f32(centroids_4bit + 12);

    for (int block = 0; block < nb; block++) {
        const float norm = GGML_FP16_TO_FP32(x[block].norm);
        const float32x4_t vnorm = vdupq_n_f32(norm);

        /* centroid * norm for all 16 entries, viewed as a 64-byte table. */
        const uint8x16x4_t lut = {{
            vreinterpretq_u8_f32(vmulq_f32(lut_raw0, vnorm)),
            vreinterpretq_u8_f32(vmulq_f32(lut_raw1, vnorm)),
            vreinterpretq_u8_f32(vmulq_f32(lut_raw2, vnorm)),
            vreinterpretq_u8_f32(vmulq_f32(lut_raw3, vnorm)),
        }};

        /* Widen before multiplying: block * QK_TURBO4 evaluated in int would
         * overflow (undefined behavior) once the row reaches 2^31 elements. */
        float * dst = y + (size_t) block * QK_TURBO4;
        const uint8_t * qs = x[block].qs;

        for (int i = 0; i < QK_TURBO4; i += 8) {
            /* 4 packed bytes = 8 nibbles; memcpy avoids an unaligned typed load. */
            uint32_t packed_raw;
            memcpy(&packed_raw, qs + i / 2, 4);
            uint8x8_t packed = vreinterpret_u8_u32(vdup_n_u32(packed_raw));

            /* Interleave low/high nibbles so lanes 0..7 follow element order. */
            uint8x8_t   lo      = vand_u8(packed, mask_lo);
            uint8x8_t   hi      = vshr_n_u8(packed, 4);
            uint8x8x2_t zipped  = vzip_u8(lo, hi);
            uint8x8_t   nibbles = zipped.val[0];

            /* nibble * 4 = byte offset of the float inside the 64-byte table. */
            uint8x8_t  nib_x4   = vshl_n_u8(nibbles, 2);
            uint8x16_t nib_x4_q = vcombine_u8(nib_x4, vdup_n_u8(0));

            /* Expand each offset to the 4 consecutive bytes of its float. */
            uint8x16_t tbl_idx_lo = vaddq_u8(vqtbl1q_u8(nib_x4_q, bc_lo), stride);
            uint8x16_t tbl_idx_hi = vaddq_u8(vqtbl1q_u8(nib_x4_q, bc_hi), stride);

            uint8x16_t out_lo = vqtbl4q_u8(lut, tbl_idx_lo);
            uint8x16_t out_hi = vqtbl4q_u8(lut, tbl_idx_hi);

            /* Same 16 bytes as before, stored through the float pointer. */
            vst1q_f32(dst + i + 0, vreinterpretq_f32_u8(out_lo));
            vst1q_f32(dst + i + 4, vreinterpretq_f32_u8(out_hi));
        }
    }
}
#endif /* aarch64 NEON */

void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
turbo_init_rotation();

Expand All @@ -573,14 +640,17 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM
const int d = QK_TURBO4;

#if TURBO4_USE_4BIT
/* 4-bit PolarQuant: nibble unpack → centroid → inverse rotate → scale */
/* TODO: add proper 4-bit centroid table to C code (currently only in Metal) */
/* 4-bit PolarQuant: nibble unpack -> centroid * norm (no inverse WHT, stays in rotated domain). */
static const float CENTROIDS_4BIT[16] = {
-0.173926f, -0.117195f, -0.089527f, -0.068756f,
-0.051262f, -0.035597f, -0.020989f, -0.006938f,
0.006938f, 0.020989f, 0.035597f, 0.051262f,
0.068756f, 0.089527f, 0.117195f, 0.173926f
};
#ifdef GGML_TURBO_NEON
dequantize_row_turbo4_0_neon_4bit(x, y, nb, CENTROIDS_4BIT);
(void)d;
#else
for (int block = 0; block < nb; block++) {
float norm = GGML_FP16_TO_FP32(x[block].norm);
float * dst = y + block * d;
Expand All @@ -589,10 +659,10 @@ void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGM
dst[i] = CENTROIDS_4BIT[idx] * norm;
}
/* No inverse WHT, dequant stays in the rotated domain.
* Q is WHT-rotated by the graph, so <Q_rot, K_rot> gives correct attention scores.
* The inverse WHT is applied to the attention output via GGML_OP_TURBO_WHT (direction=1) in the graph.
*/
* Q is WHT-rotated by the graph; the inverse WHT is applied to attention output via GGML_OP_TURBO_WHT.
*/
}
#endif /* GGML_TURBO_NEON */
#else
/* Legacy 3-bit + QJL dequant */
turbo_init_qjl();
Expand Down