From 0d095427ff734686b09d9792b4ff0f6d755308f2 Mon Sep 17 00:00:00 2001 From: SXX Date: Tue, 3 Jun 2025 14:26:02 +0800 Subject: [PATCH] sgemm: simplify kernel_x86_avx logic and reduce shuffle overhead --- src/sgemm_kernel.rs | 217 ++++++++++++++++---------------------------- 1 file changed, 79 insertions(+), 138 deletions(-) diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index 28fe8ed..432c41d 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -325,102 +325,65 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T, let (mut a, mut b) = if prefer_row_major_c { (a, b) } else { (b, a) }; let (rsc, csc) = if prefer_row_major_c { (rsc, csc) } else { (csc, rsc) }; - macro_rules! shuffle_mask { - ($z:expr, $y:expr, $x:expr, $w:expr) => { - ($z << 6) | ($y << 4) | ($x << 2) | $w - } - } macro_rules! permute_mask { ($z:expr, $y:expr, $x:expr, $w:expr) => { ($z << 6) | ($y << 4) | ($x << 2) | $w } } - macro_rules! permute2f128_mask { - ($y:expr, $x:expr) => { - (($y << 4) | $x) - } - } - // Start data load before each iteration let mut av = _mm256_load_ps(a); - let mut bv = _mm256_load_ps(b); + let mut bvl = _mm256_broadcast_ps(&*(b.add(0) as *const _)); + let mut bvh = _mm256_broadcast_ps(&*(b.add(4) as *const _)); // Compute A B unroll_by_with_last!(4 => k, is_last, { // We compute abij = ai bj // - // Load b as one contiguous vector - // Load a as striped vectors // - // Shuffle the abij elements in order after the loop. - // - // Note this scheme copied and transposed from the BLIS 8x8 sgemm - // microkernel. - // - // Our a indices are striped and our b indices are linear. In - // the variable names below, we always have doubled indices so - // for example a0246 corresponds to a vector of a0 a0 a2 a2 a4 a4 a6 a6. - // - // ab0246: ab2064: ab4602: ab6420: - // ( ab00 ( ab20 ( ab40 ( ab60 - // ab01 ab21 ab41 ab61 - // ab22 ab02 ab62 ab42 - // ab23 ab03 ab63 ab43 - // ab44 ab64 ab04 ab24 - // ab45 ab65 ab05 ab25 - // ab66 ab46 ab26 ab06 - // ab67 ) ab47 ) ab27 ) ab07 ) + // ab0: ab1: ab2: ab3: + // ( ab00 ( ab10 ( ab20 ( ab30 + // ab11 ab21 ab31 ab01 + // ab22 ab32 ab02 ab12 + // ab33 ab03 ab13 ab23 + // ab40 ab50 ab60 ab70 + // ab51 ab61 ab71 ab41 + // ab62 ab72 ab42 ab52 + // ab73 ) ab43 ) ab53 ) ab63 ) // // ab1357: ab3175: ab5713: ab7531: - // ( ab10 ( ab30 ( ab50 ( ab70 - // ab11 ab31 ab51 ab71 - // ab32 ab12 ab72 ab52 - // ab33 ab13 ab73 ab53 - // ab54 ab74 ab14 ab34 - // ab55 ab75 ab15 ab35 - // ab76 ab56 ab36 ab16 - // ab77 ) ab57 ) ab37 ) ab17 ) - - const PERM32_2301: i32 = permute_mask!(1, 0, 3, 2); - const PERM128_30: i32 = permute2f128_mask!(0, 3); - - // _mm256_moveldup_ps(av): - // vmovsldup ymm2, ymmword ptr [rax] - // - // Load and duplicate each even word: - // ymm2 ← [a0 a0 a2 a2 a4 a4 a6 a6] - // - // _mm256_movehdup_ps(av): - // vmovshdup ymm2, ymmword ptr [rax] - // - // Load and duplicate each odd word: - // ymm2 ← [a1 a1 a3 a3 a5 a5 a7 a7] - // + // ( ab04 ( ab14 ( ab24 ( ab34 + // ab15 ab25 ab35 ab05 + // ab26 ab36 ab06 ab16 + // ab37 ab07 ab17 ab27 + // ab44 ab54 ab64 ab74 + // ab55 ab65 ab75 ab45 + // ab66 ab76 ab46 ab56 + // ab77 ) ab47 ) ab57 ) ab67 ) - let a0246 = _mm256_moveldup_ps(av); // Load: a0 a0 a2 a2 a4 a4 a6 a6 - let a2064 = _mm256_permute_ps(a0246, PERM32_2301); + let a01234567 = av; + let a12305674 = _mm256_permute_ps(av, permute_mask!(0, 3, 2, 1)); + let a23016745 = _mm256_permute_ps(av, permute_mask!(1, 0, 3, 2)); + let a30127456 = _mm256_permute_ps(av, permute_mask!(2, 1, 0, 3)); - let a1357 = _mm256_movehdup_ps(av); // Load: a1 a1 a3 a3 a5 a5 a7 a7 - let a3175 = _mm256_permute_ps(a1357, PERM32_2301); + ab[0] = MA::multiply_add(a01234567, bvl, ab[0]); + ab[4] = MA::multiply_add(a01234567, bvh, ab[4]); - let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30); + ab[1] = MA::multiply_add(a12305674, bvl, ab[1]); + ab[5] = MA::multiply_add(a12305674, bvh, ab[5]); - ab[0] = MA::multiply_add(a0246, bv, ab[0]); - ab[1] = MA::multiply_add(a2064, bv, ab[1]); - ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]); - ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]); + ab[2] = MA::multiply_add(a23016745, bvl, ab[2]); + ab[6] = MA::multiply_add(a23016745, bvh, ab[6]); - ab[4] = MA::multiply_add(a1357, bv, ab[4]); - ab[5] = MA::multiply_add(a3175, bv, ab[5]); - ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]); - ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]); + ab[3] = MA::multiply_add(a30127456, bvl, ab[3]); + ab[7] = MA::multiply_add(a30127456, bvh, ab[7]); if !is_last { a = a.add(MR); b = b.add(NR); - bv = _mm256_load_ps(b); + bvl = _mm256_broadcast_ps(&*(b.add(0) as *const _)); + bvh = _mm256_broadcast_ps(&*(b.add(4) as *const _)); av = _mm256_load_ps(a); } }); @@ -428,74 +391,52 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T, let alphav = _mm256_set1_ps(alpha); // Permute to put the abij elements in order - // - // shufps 0xe4: 22006644 00224466 -> 22226666 - // - // vperm2 0x30: 00004444 44440000 -> 00000000 - // vperm2 0x12: 00004444 44440000 -> 44444444 - // - - let ab0246 = ab[0]; - let ab2064 = ab[1]; - let ab4602 = ab[2]; // reverse order - let ab6420 = ab[3]; // reverse order - - let ab1357 = ab[4]; - let ab3175 = ab[5]; - let ab5713 = ab[6]; // reverse order - let ab7531 = ab[7]; // reverse order - - const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0); - debug_assert_eq!(SHUF_0123, 0xE4); - - const PERM128_02: i32 = permute2f128_mask!(2, 0); - const PERM128_31: i32 = permute2f128_mask!(1, 3); - - // No elements are "shuffled" in truth, they all stay at their index - // but we combine vectors to de-stripe them. - // - // For example, the first shuffle below uses 0 1 2 3 which - // corresponds to the X0 X1 Y2 Y3 sequence etc: - // - // variable - // X ab00 ab01 ab22 ab23 ab44 ab45 ab66 ab67 ab0246 - // Y ab20 ab21 ab02 ab03 ab64 ab65 ab46 ab47 ab2064 - // - // X0 X1 Y2 Y3 X4 X5 Y6 Y7 - // = ab00 ab01 ab02 ab03 ab44 ab45 ab46 ab47 ab0044 - - let ab0044 = _mm256_shuffle_ps(ab0246, ab2064, SHUF_0123); - let ab2266 = _mm256_shuffle_ps(ab2064, ab0246, SHUF_0123); - - let ab4400 = _mm256_shuffle_ps(ab4602, ab6420, SHUF_0123); - let ab6622 = _mm256_shuffle_ps(ab6420, ab4602, SHUF_0123); - - let ab1155 = _mm256_shuffle_ps(ab1357, ab3175, SHUF_0123); - let ab3377 = _mm256_shuffle_ps(ab3175, ab1357, SHUF_0123); - - let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123); - let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123); - - let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02); - let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31); - - let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02); - let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31); - - let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02); - let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31); - - let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02); - let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31); - - ab[0] = ab0000; - ab[1] = ab1111; - ab[2] = ab2222; - ab[3] = ab3333; - ab[4] = ab4444; - ab[5] = ab5555; - ab[6] = ab6666; - ab[7] = ab7777; + let t0 = ab[0]; + let t1 = ab[1]; + let t2 = ab[2]; + let t3 = ab[3]; + + let (t0, t1, t2, t3) = ( + _mm256_blend_ps(t0, t3, 0b10101010), + _mm256_blend_ps(t1, t0, 0b10101010), + _mm256_blend_ps(t2, t1, 0b10101010), + _mm256_blend_ps(t3, t2, 0b10101010), + ); + + let (t0, t1, t2, t3) = ( + _mm256_blend_ps(t0, t2, 0b11001100), + _mm256_blend_ps(t1, t3, 0b11001100), + _mm256_blend_ps(t2, t0, 0b11001100), + _mm256_blend_ps(t3, t1, 0b11001100), + ); + + let t4 = ab[4]; + let t5 = ab[5]; + let t6 = ab[6]; + let t7 = ab[7]; + + let (t4, t5, t6, t7) = ( + _mm256_blend_ps(t4, t7, 0b10101010), + _mm256_blend_ps(t5, t4, 0b10101010), + _mm256_blend_ps(t6, t5, 0b10101010), + _mm256_blend_ps(t7, t6, 0b10101010), + ); + + let (t4, t5, t6, t7) = ( + _mm256_blend_ps(t4, t6, 0b11001100), + _mm256_blend_ps(t5, t7, 0b11001100), + _mm256_blend_ps(t6, t4, 0b11001100), + _mm256_blend_ps(t7, t5, 0b11001100), + ); + + ab[0] = _mm256_permute2f128_ps(t0, t4, 0x20); + ab[1] = _mm256_permute2f128_ps(t1, t5, 0x20); + ab[2] = _mm256_permute2f128_ps(t2, t6, 0x20); + ab[3] = _mm256_permute2f128_ps(t3, t7, 0x20); + ab[4] = _mm256_permute2f128_ps(t0, t4, 0x31); + ab[5] = _mm256_permute2f128_ps(t1, t5, 0x31); + ab[6] = _mm256_permute2f128_ps(t2, t6, 0x31); + ab[7] = _mm256_permute2f128_ps(t3, t7, 0x31); // Compute α (A B) // Compute here if we don't have fma, else pick up α further down