diff --git a/core/src/ops/cnn/conv/conv.rs b/core/src/ops/cnn/conv/conv.rs index c6d5908b1f..ced2354b36 100644 --- a/core/src/ops/cnn/conv/conv.rs +++ b/core/src/ops/cnn/conv/conv.rs @@ -33,7 +33,7 @@ use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec}; use crate::ops::nn::{BaseDataShape, DataFormat, DataShape}; use tract_linalg::mmm::{MMMInputFormat, MatMatMul}; -use tract_linalg::pack::PackedFormat; +use tract_linalg::pack::{PackedFormat, PackedI8K4}; #[derive(Debug, Clone, new, Hash, PartialEq, Eq)] pub struct Conv { @@ -123,13 +123,11 @@ impl Conv { &[kernel], )? } else { - let format = format - .downcast_ref::() - .context("Expect regular packing for numeric weights")?; + // PackedFormat or a custom numeric packer (e.g. PackedI8K4). model.wire_node( format!("{name}.prep_kernel.pack"), OptMatMulPack { - packers: vec![format.clone()], + packers: vec![dyn_clone::clone_box(format)], k_axis: 2, mn_axis: 1, mode_picker: ModePicker::Single, @@ -240,7 +238,11 @@ impl Conv { &sum_ker_n_g_c, )?; - ensure!(mmm.packings()[packing].1.downcast_ref::().is_some()); + ensure!( + mmm.packings()[packing].1.downcast_ref::().is_some() + || mmm.packings()[packing].1.downcast_ref::().is_some(), + "Im2Col/QSumB support PackedFormat or PackedI8K4 activation packings" + ); let mut sum_x = model.wire_node( format!("{name}.sum_x"), super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k }, diff --git a/core/src/ops/cnn/conv/im2col.rs b/core/src/ops/cnn/conv/im2col.rs index ac67c210d8..6cec51ba95 100644 --- a/core/src/ops/cnn/conv/im2col.rs +++ b/core/src/ops/cnn/conv/im2col.rs @@ -1,7 +1,8 @@ use tract_linalg::mmm::{ - EagerPackedInput, MMMInputValue, MatMatMul, PackedExoticFact, PackedMatrixStorage, + EagerPackedInput, MMMInputFormat, MMMInputValue, MatMatMul, PackedExoticFact, + PackedMatrixStorage, }; -use tract_linalg::pack::{PackedFormat, PackingWriter}; +use tract_linalg::pack::{PackedFormat, PackedI8K4, PackingWriter}; use crate::internal::*; use ndarray::prelude::*; @@ 
-23,7 +24,8 @@ struct SymbolicGeometry { group: usize, pool_spec: PoolSpec, pool_geometry: PoolGeometry, - b_pack: PackedFormat, + // The kernel's activation packing: PackedFormat (K-major) or PackedI8K4 (K=4-inner). + out_format: Box, k: usize, } @@ -32,7 +34,7 @@ struct ConcreteGeometry { pool: ConcretePoolGeometry, pub n: usize, k: usize, - pub b_pack: PackedFormat, + pub out_format: Box, pub ci_per_group: usize, patcher: Patcher, input_shape_with_n: DataShape, @@ -40,10 +42,10 @@ struct ConcreteGeometry { } impl GeometryBound { - pub fn b_pack(&self) -> &PackedFormat { + pub fn out_format(&self) -> &dyn MMMInputFormat { match self { - GeometryBound::Symbolic(s) => &s.b_pack, - GeometryBound::Concrete(s) => &s.b_pack, + GeometryBound::Symbolic(s) => &*s.out_format, + GeometryBound::Concrete(s) => &*s.out_format, } } pub fn k(&self) -> usize { @@ -88,7 +90,7 @@ impl ResolveTo for SymbolicGeometry { n, k: self.k, ci_per_group, - b_pack: self.b_pack.clone(), + out_format: self.out_format.clone(), patcher, input_shape_with_n, packed_shape, @@ -105,15 +107,10 @@ impl Im2Col { mmm: Box, packing: usize, ) -> TractResult { - let b_pack = mmm.packings()[packing] - .1 - .downcast_ref::() - .context("Im2Col expects regular packed format")? 
- .clone(); - + let out_format = dyn_clone::clone_box(&*mmm.packings()[packing].1); let pool_geometry = pool_spec.compute_geo(input_full_shape)?; let geometry: GeometryBound<_, _> = - SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, b_pack, k } + SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, out_format, k } .into(); let geometry = geometry.optimize_if(input_full_shape.as_concrete())?; Ok(Im2Col { pool_spec, group, geometry }) @@ -156,8 +153,21 @@ impl EvalOp for Im2Col { if !self.pool_spec.data_format.has_n() { input.insert_axis(0)?; } - let panel_bytes = - geometry.b_pack.single_panel_len(geometry.k) * input.datum_type().size_of(); + let dt = input.datum_type(); + let r = geometry.out_format.r(); + // Buffer geometry. zero_init for PackedI8K4: the K=4-inner writer skips + // the K-padding lanes (k..k_aligned), which SMOPA accumulates — they must + // be 0. PackedFormat has no K padding; its mn-padding maps to discarded + // output rows, so uninitialized is fine (matches prior behaviour). 
+ let (single_panel_len, buf_align, zero_init) = + if let Some(pf) = geometry.out_format.downcast_ref::() { + (pf.single_panel_len(geometry.k), pf.alignment(), false) + } else if let Some(p4) = geometry.out_format.downcast_ref::() { + (p4.single_panel_len(geometry.k), p4.alignment(), true) + } else { + bail!("Im2Col: unsupported packing format {:?}", geometry.out_format) + }; + let panel_bytes = single_panel_len * dt.size_of(); let n_batches = *geometry.input_shape_with_n.n().unwrap_or(&1); let n_groups = self.group; @@ -169,12 +179,15 @@ impl EvalOp for Im2Col { let n = if geometry.pool.output_shape.shape.contains(&0) { 0 } else { geometry.n }; let mut data = Tensor::uninitialized_aligned_dt( - input.datum_type(), - &[geometry.b_pack.len(geometry.k, n)], - geometry.b_pack.alignment(), + dt, + &[n.divceil(r) * single_panel_len], + buf_align, )?; + if zero_init { + data.as_bytes_mut().fill(0); + } if n > 0 { - dispatch_copy_by_size!(Patcher::patch(input.datum_type())( + dispatch_copy_by_size!(Patcher::patch(dt)( &geometry.patcher, &geometry, &input, @@ -185,7 +198,7 @@ impl EvalOp for Im2Col { } values.push(Box::new(EagerPackedInput { fact: PackedExoticFact { - format: Box::new(geometry.b_pack.clone()), + format: geometry.out_format.clone(), k: geometry.k, mn: n.to_dim(), }, @@ -211,7 +224,7 @@ impl TypedOp for Im2Col { let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?; let mn = output_shape.hw_dims().iter().product::(); let pof = PackedExoticFact { - format: Box::new(self.geometry.b_pack().clone()), + format: dyn_clone::clone_box(self.geometry.out_format()), k: self.geometry.k(), mn, }; @@ -259,34 +272,57 @@ impl Patcher { pack: &'p mut TensorView, g: usize, pad_value: Option<&Tensor>, + ) -> TractResult<()> { + // Pick the packing writer for the kernel's output format, then run the + // (writer-generic) patcher. PackedFormat keeps the K-major fast path; + // PackedI8K4 writes the SMOPA K=4-inner layout in the same single pass. 
+ let ptr = unsafe { pack.as_slice_mut_unchecked::().as_mut_ptr() }; + if let Some(pf) = geo.out_format.downcast_ref::() { + let mut w = pf.write_with_k_outer(ptr, geo.k, geo.n); + self.run::(geo, input, g, pad_value, &mut w) + } else if let Some(p4) = geo.out_format.downcast_ref::() { + let mut w = p4.write_with_k_outer(ptr, geo.k, geo.n); + self.run::(geo, input, g, pad_value, &mut w) + } else { + bail!("Im2Col: unsupported packing format {:?}", geo.out_format) + } + } + + fn run>( + &self, + geo: &ConcreteGeometry, + input: &TensorView, + g: usize, + pad_value: Option<&Tensor>, + writer: &mut W, ) -> TractResult<()> { match self { - Patcher::Valid1d => Self::valid_1d::(geo, input, pack, g), - Patcher::Valid2d => Self::valid_2d::(geo, input, pack, g), - Patcher::Padded2d => Self::padded_2d::( + Patcher::Valid1d => Self::valid_1d::(geo, input, g, writer), + Patcher::Valid2d => Self::valid_2d::(geo, input, g, writer), + Patcher::Padded2d => Self::padded_2d::( geo, input, - pack, g, pad_value.unwrap_or(&Tensor::zero_scalar::()?), + writer, ), - _ => Self::generic::( + _ => Self::generic::( geo, input, - pack, g, pad_value.unwrap_or(&Tensor::zero_scalar::()?), + writer, ), } } #[inline(never)] - fn generic<'p, T: Copy + Datum>( - geometry: &'p ConcreteGeometry, + fn generic>( + geometry: &ConcreteGeometry, input: &TensorView, - pack: &'p mut TensorView, g: usize, pad_value: &Tensor, + writer: &mut W, ) -> TractResult<()> { unsafe { let pad_value = *pad_value.to_scalar_unchecked(); @@ -307,25 +343,28 @@ impl Patcher { } } } - geometry.b_pack.pack(pack, mega_matrix.view(), 0, 1); + // mega_matrix is [k, n] (k-major); feed K-outer to the writer, which + // lays out the kernel's packing (K-major for PackedFormat, K=4-inner + // for PackedI8K4) — byte-identical to PackedFormat::pack for the former. 
+ let mv = mega_matrix.as_slice_unchecked::(); + for kk in 0..geometry.k { + writer.write_slice(&mv[kk * geometry.n..(kk + 1) * geometry.n]); + } Ok(()) } } #[inline(never)] - fn valid_1d<'p, T: Copy + Datum>( - geometry: &'p ConcreteGeometry, + fn valid_1d>( + geometry: &ConcreteGeometry, input: &TensorView, - pack: &'p mut TensorView, g: usize, + writer: &mut W, ) -> TractResult<()> { unsafe { let x_stride = *geometry.input_shape_with_n.h_stride() as isize * geometry.pool.patch.spec.strides[0] as isize; let c_stride = *geometry.input_shape_with_n.c_stride() as isize; - let pack = pack.as_slice_mut_unchecked::(); - let mut writer = - geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n); let iptr = input.as_ptr_unchecked::(); let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride()); let output_x = *geometry.pool.patch.output_shape.get_unchecked(0); @@ -356,16 +395,15 @@ impl Patcher { } #[inline(never)] - fn padded_2d<'p, T: Copy + Datum>( - geometry: &'p ConcreteGeometry, + fn padded_2d>( + geometry: &ConcreteGeometry, input: &TensorView, - pack: &'p mut TensorView, g: usize, pad_value: &Tensor, + writer: &mut W, ) -> TractResult<()> { unsafe { let pad_value = *pad_value.to_scalar_unchecked(); - let pack = pack.as_slice_mut_unchecked::(); let y_stride = geometry.pool.patch.spec.strides[0] as isize; let x_stride = geometry.pool.patch.spec.strides[1] as isize; let shape = &geometry.input_shape_with_n; @@ -375,8 +413,6 @@ impl Patcher { let input_heigth = shape.hw_dims()[0] as isize; let input_width = shape.hw_dims()[1] as isize; let kernel_len = geometry.pool.patch.standard_layout_data_field.len(); - let mut writer = - geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n); let iptr = input.as_ptr_unchecked::(); let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride()); let output_width = *geometry.pool.patch.output_shape.get_unchecked(1); @@ -402,22 +438,22 @@ impl Patcher { 
Self::padded_2d_invalid_x_loop( valid_x_start as usize, pad_value, - &mut writer, + &mut *writer, ); Self::padded_2d_valid_x_loop( valid_x_start, valid_x_end, x_stride_ptr, iptr, - &mut writer, + &mut *writer, ); Self::padded_2d_invalid_x_loop( output_width - valid_x_end as usize, pad_value, - &mut writer, + &mut *writer, ); } else { - Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut writer); + Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut *writer); } } } @@ -427,10 +463,10 @@ impl Patcher { } #[inline(never)] - unsafe fn padded_2d_invalid_x_loop( + unsafe fn padded_2d_invalid_x_loop>( count: usize, pad_value: T, - writer: &mut tract_linalg::pack::KOutWriter, + writer: &mut W, ) { for _ in 0..count { writer.write(pad_value); @@ -438,12 +474,12 @@ impl Patcher { } #[inline(never)] - unsafe fn padded_2d_valid_x_loop( + unsafe fn padded_2d_valid_x_loop>( x_min: isize, x_max: isize, x_stride_ptr: isize, iptr: *const T, - writer: &mut tract_linalg::pack::KOutWriter, + writer: &mut W, ) { // Fast path: x_stride_ptr == 1 means consecutive x values are at // consecutive memory addresses, so the inner loop is a contiguous @@ -461,22 +497,19 @@ impl Patcher { } #[inline(never)] - fn valid_2d<'p, T: Copy + Datum>( - geometry: &'p ConcreteGeometry, + fn valid_2d>( + geometry: &ConcreteGeometry, input: &TensorView, - pack: &'p mut TensorView, g: usize, + writer: &mut W, ) -> TractResult<()> { unsafe { - let pack = pack.as_slice_mut_unchecked::(); let shape = &geometry.input_shape_with_n; let y_stride = geometry.pool.patch.spec.strides[0] as isize; let x_stride = geometry.pool.patch.spec.strides[1] as isize; let y_stride_ptr = y_stride * *shape.h_stride() as isize; let x_stride_ptr = x_stride * *shape.w_stride() as isize; let c_stride_ptr = *shape.c_stride() as isize; - let mut writer = - geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n); let iptr = input.as_ptr_unchecked::(); let iptr = iptr.add(g * geometry.ci_per_group * 
shape.c_stride()); let output_y = *geometry.pool.patch.output_shape.get_unchecked(0); diff --git a/core/src/ops/cnn/conv/lazy_im2col.rs b/core/src/ops/cnn/conv/lazy_im2col.rs index 49254a59ae..6bb5d9ff46 100644 --- a/core/src/ops/cnn/conv/lazy_im2col.rs +++ b/core/src/ops/cnn/conv/lazy_im2col.rs @@ -124,7 +124,7 @@ impl TypedOp for LazyIm2Col { let exotic_fact = DynPackedExoticFact { k: self.params.k_byte_offsets.len().to_dim(), mn: self.params.n_byte_offsets.len().to_dim(), - packers: vec![self.params.packer.clone()], + packers: vec![Box::new(self.params.packer.clone()) as Box], }; Ok(tvec!(inputs[0].datum_type.fact([1, self.group]).with_exotic_fact(exotic_fact))) } diff --git a/core/src/ops/cnn/conv/q_sum_b.rs b/core/src/ops/cnn/conv/q_sum_b.rs index 434abbdf70..53f04aefcb 100644 --- a/core/src/ops/cnn/conv/q_sum_b.rs +++ b/core/src/ops/cnn/conv/q_sum_b.rs @@ -1,5 +1,6 @@ use crate::internal::*; use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage}; +use tract_linalg::pack::PackedI8K4; use tract_ndarray::prelude::*; #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -80,14 +81,27 @@ impl QSumB { output: &mut [i32], ) -> TractResult<()> { let (r, k, n) = (input.format().r(), input.k(), input.mn()); + // PackedI8K4 is K=4-inner: element (ik, ir) at (ik/4)*r*4 + ir*4 + ik%4, + // and the panel is k padded up to a multiple of 4. PackedFormat is K-major. 
+ let is_k4 = input.format().downcast_ref::().is_some(); + let panel_len = if is_k4 { r * k.div_ceil(4) * 4 } else { r * k }; let panels = n.divceil(r); for ipanel in 0..panels { let panel = input.panel_bytes(ipanel, None)?; - let panel: &[T] = unsafe { std::slice::from_raw_parts(panel as *const T, r * k) }; + let panel: &[T] = unsafe { std::slice::from_raw_parts(panel as *const T, panel_len) }; let mut vec = vec![0i32; r]; - for ik in 0..k { - for ir in 0..r { - vec[ir] += panel[ik * r + ir].as_(); + if is_k4 { + for ik in 0..k { + let kbase = (ik / 4) * r * 4 + ik % 4; + for ir in 0..r { + vec[ir] += panel[kbase + ir * 4].as_(); + } + } + } else { + for ik in 0..k { + for ir in 0..r { + vec[ir] += panel[ik * r + ir].as_(); + } } } let len = r.min(n - r * ipanel); diff --git a/core/src/ops/einsum/einsum_matmul.rs b/core/src/ops/einsum/einsum_matmul.rs index 32df705c35..cf497af0c1 100644 --- a/core/src/ops/einsum/einsum_matmul.rs +++ b/core/src/ops/einsum/einsum_matmul.rs @@ -3,7 +3,6 @@ use std::ops::Deref; use tract_itertools::{izip, multiunzip}; use tract_linalg::block_quant::PackedBlockQuantFormat; -use tract_linalg::pack::PackedFormat; use super::*; use crate::ops::cast::cast; @@ -864,23 +863,21 @@ fn optimized_mat_mul( let name = &node.name; let pack_a: Box = if input_facts[0].konst.is_some() { - if let Some(pf) = left_pack.downcast_ref::() { - Box::new(OptMatMulPack { - packers: vec![pf.clone()], - mode_picker: ModePicker::Single, - k_axis: op.a_k(), - mn_axis: op.a_m(), - }) - } else if let Some(packed_format) = - left_pack.downcast_ref::().cloned() - { + if let Some(packed_format) = left_pack.downcast_ref::().cloned() { Box::new(OptSimpleMatMulPack { packed_format, k: input_shapes[0][op.a_k()].to_usize().unwrap(), m: input_shapes[0][op.a_m()].to_usize().unwrap(), }) } else { - bail!("Unexpected static input format {left_pack:?}"); + // PackedFormat or a custom packer (e.g. PackedI8K4); OptMatMulPack + // dispatches on the concrete format at pack time. 
+ Box::new(OptMatMulPack { + packers: vec![left_pack], + mode_picker: ModePicker::Single, + k_axis: op.a_k(), + mn_axis: op.a_m(), + }) } } else { Box::new(OptMatMulPack { @@ -888,11 +885,8 @@ fn optimized_mat_mul( .iter() .map(|(mmm, p, pe)| { pe.as_ref() - .map(|pe| &pe.from) - .unwrap_or(&mmm.packings()[*p].0) - .downcast_ref::() - .unwrap() - .clone() + .map(|pe| pe.from.clone()) + .unwrap_or_else(|| mmm.packings()[*p].0.clone()) }) .collect(), mode_picker: mode_picker.clone(), @@ -907,12 +901,7 @@ fn optimized_mat_mul( OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), - packers: impls - .iter() - .map(|(mmm, p, _)| { - mmm.packings()[*p].1.downcast_ref::().unwrap().clone() - }) - .collect(), + packers: impls.iter().map(|(mmm, p, _)| mmm.packings()[*p].1.clone()).collect(), mode_picker: mode_picker.clone(), }, &[taps[1]], diff --git a/core/src/ops/einsum/kernel_selection.rs b/core/src/ops/einsum/kernel_selection.rs index e2a7ed988a..aa00942b18 100644 --- a/core/src/ops/einsum/kernel_selection.rs +++ b/core/src/ops/einsum/kernel_selection.rs @@ -53,8 +53,14 @@ pub fn strategize(model: &TypedModel, node: &TypedNode, op: &EinSumMatMul) -> Tr return Ok(single_strat(it)); } if op.n.as_i64().is_some_and(|n| n > 1) { - let it = - impls.into_iter().max_by_key(|(m, _, pe)| (pe.is_none(), m.nr() * m.mr())).unwrap(); + // For a 2D matmul (n > 1) a GEMV kernel (nr == 1) is a poor fit: it + // processes one output column per pass. Demote nr == 1 so it never wins + // the `nr * mr` tie against a square tile (e.g. i8 64x1 vs 8x8 both have + // nr*mr == 64). Ordering among nr > 1 kernels is left untouched. 
+ let it = impls + .into_iter() + .max_by_key(|(m, _, pe)| (pe.is_none(), m.nr() > 1, m.nr() * m.mr())) + .unwrap(); return Ok(single_strat(it)); } let mut grouped_by_left_packing = Vec::<(&dyn MMMInputFormat, Vec<_>)>::new(); @@ -79,7 +85,21 @@ pub fn strategize(model: &TypedModel, node: &TypedNode, op: &EinSumMatMul) -> Tr (p, best_for_mmv, best_for_mmm) }) .max_by_key(|(_, mmv, mmm)| { - (mmv.0.nr() == 1 && mmm.0.nr() > 1, mmv.2.is_none(), mmm.0.mr(), mmm.0.nr()) + // When no group offers the ideal (true GEMV nr==1 + true matrix nr>1) + // pair, still prefer a group whose matrix-role kernel is a real matrix + // (nr > 1) over a GEMV-only group. Without this, int8 — whose GEMV + // (64x1), SMLAL (8x8) and SDOT (8x8_dot) kernels each use a different + // packing, so no single group is ideal — falls through to `mmm.mr` and + // picks the 64x1 GEMV even for symbolic (dynamic) n. f32/f16/block-quant + // are unaffected: they have a packing group that IS ideal (e.g. f32 + // 32x1/32x3, q40 32x1/32x3), so the first key already decides. + ( + mmv.0.nr() == 1 && mmm.0.nr() > 1, + mmv.2.is_none(), + mmm.0.nr() > 1, + mmm.0.mr(), + mmm.0.nr(), + ) }) .unwrap(); @@ -117,7 +137,11 @@ pub fn list_impls( .mmm_impls() .iter() .filter(|mmm| { - op.acceptable_accumulators().contains(&mmm.internal_type()) + // Only consider kernels runnable on this CPU: e.g. the SDOT i8 kernel + // carries a FEAT_DotProd platform predicate, and must not be selected on + // a CPU that would trap on the instruction. 
+ mmm.is_supported_here() + && op.acceptable_accumulators().contains(&mmm.internal_type()) && mmm.stores().contains(&op.operating_dt.unquantized()) }) .flat_map(move |mmm| { diff --git a/core/src/ops/matmul/pack.rs b/core/src/ops/matmul/pack.rs index 41ecb42c81..8a3bcc93ae 100644 --- a/core/src/ops/matmul/pack.rs +++ b/core/src/ops/matmul/pack.rs @@ -1,17 +1,36 @@ use crate::axes::Axis; use crate::internal::*; use ndarray::*; +use tract_linalg::WeightType; use tract_linalg::block_quant::{ BlockQuantStorage, PackedBlockQuantFact, PackedBlockQuantFormat, block_quant_slice, }; -use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage}; -use tract_linalg::pack::PackedFormat; +use tract_linalg::mmm::{MMMInputFormat, MMMInputValue, PackedMatrixStorage}; +use tract_linalg::pack::{PackedFormat, PackedI8K4}; use super::ModePicker; +// Pack one (possibly strided) view with a dynamic packing format. Keeps the +// PackedFormat fast path byte-identical; routes the K=4-inner SMOPA packer +// (PackedI8K4) through its view packer. Other formats are unsupported here. 
+fn pack_view_with( + packer: &dyn MMMInputFormat, + t: &TensorView, + k_axis: usize, + mn_axis: usize, +) -> TractResult> { + if let Some(pf) = packer.downcast_ref::() { + pf.pack_tensor_view(t, k_axis, mn_axis) + } else if let Some(p4) = packer.downcast_ref::() { + p4.pack_view(t, k_axis, mn_axis) + } else { + bail!("OptMatMulPack does not support packing format {packer:?}") + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct OptMatMulPack { - pub(crate) packers: Vec, + pub(crate) packers: Vec>, pub(crate) mode_picker: ModePicker, pub(crate) k_axis: usize, pub(crate) mn_axis: usize, @@ -88,7 +107,7 @@ impl OptMatMulPack { let packer = &self.packers[mode]; let output_shape: TVec = self.output_shape(input.shape()); let stores = if output_shape.iter().all(|d| *d == 1) { - let packed = packer.pack_tensor_view(&input.view(), self.k_axis, self.mn_axis)?; + let packed = pack_view_with(&**packer, &input.view(), self.k_axis, self.mn_axis)?; PackedMatrixStorage::new_batched(&output_shape, vec![packed]) .into_tensor(input.datum_type()) } else { @@ -106,7 +125,8 @@ impl OptMatMulPack { .map(|(x, s)| *x as isize * s) .sum::() * input.datum_type().size_of() as isize; - values.push(packer.pack_tensor_view( + values.push(pack_view_with( + &**packer, &TensorView::from_bytes(&input, offset, input.shape(), input.strides()), self.k_axis, self.mn_axis, @@ -131,12 +151,17 @@ impl OptMatMulPack { pub struct DynPackedExoticFact { pub k: TDim, pub mn: TDim, - pub packers: Vec, + pub packers: Vec>, } impl ExoticFact for DynPackedExoticFact { fn buffer_sizes(&self) -> TVec { - tvec!(self.k.clone() * &self.mn * self.packers[0].dt.size_of()) + let elem_bytes = match self.packers[0].precursor() { + WeightType::Plain(dt) => dt.size_of(), + // OptMatMulPack only ever carries plain (PackedFormat / PackedI8K4) packers. 
+ WeightType::BlockQuant(_) => 1, + }; + tvec!(self.k.clone() * &self.mn * elem_bytes) } } diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 68955ac179..bb1c5bb3f1 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -83,6 +83,10 @@ harness = false name = "mm_for_asr_am" harness = false +[[bench]] +name = "qmmm_i8" +harness = false + [[bench]] name = "hardswish" harness = false diff --git a/linalg/arm64/arm64simd/arm64simd_mmm_i32_8x8_dot.S.j2 b/linalg/arm64/arm64simd/arm64simd_mmm_i32_8x8_dot.S.j2 new file mode 100644 index 0000000000..0a63a74edf --- /dev/null +++ b/linalg/arm64/arm64simd/arm64simd_mmm_i32_8x8_dot.S.j2 @@ -0,0 +1,235 @@ +// vim: ft=arm + +// C tile regs: +// - x19-x29 to preserve (but x19, x28, x29 not used) +// - d8..d15 to preserve +// - v16 to v31, no need to preserve +// +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] +// v16[2] v18[2] +// v16[3] v18[3] +// +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] +// v17[2] v19[2] +// v17[3] v19[3] + +// no preservation either for v0-v7... +// packed A buffering (2x8 values): alternating v0, v1 with v2, v3 +// packed B buffering (2x8 values): alternating v4, v5 with v6, v7 + +.text +.align 4 + +.cpu generic+fp+simd+dotprod +.global {{G}}arm64simd_mmm_i32_8x8_dot_{{suffix}} +{{G}}arm64simd_mmm_i32_8x8_dot_{{suffix}}: + +/* + prfm pldl1keep, [x1] + prfm pldl1keep, [x2] +*/ + stp x20, x21, [sp, #-16]! + stp x22, x23, [sp, #-16]! + stp x24, x25, [sp, #-16]! + stp x26, x27, [sp, #-16]! + + stp d8, d9, [sp, #-16]! + stp d10, d11, [sp, #-16]! + stp d12, d13, [sp, #-16]! + stp d14, d15, [sp, #-16]! 
+ +{% include "dispatcher.j2" %} + +.add_mat_mul: + ldp x2, x4, [x0, #24] // b, packing + ldp x3, x1, [x0, #8] // k, a + + cmp x3, #0 + beq .non_linear_loop + + cmp x4, #1 + beq .packed_packed_loop_1_i8i8 + +.packed_packed_loop_1: + + ld1 { v0.4s, v1.4s }, [ x1 ], #32 + ld1 { v4.4s, v5.4s }, [ x2 ], #32 + + mla v16.4s, v0.4s, v4.s[0] + mla v17.4s, v1.4s, v4.s[0] + mla v18.4s, v0.4s, v4.s[1] + mla v19.4s, v1.4s, v4.s[1] + + mla v20.4s, v0.4s, v4.s[2] + mla v21.4s, v1.4s, v4.s[2] + mla v22.4s, v0.4s, v4.s[3] + mla v23.4s, v1.4s, v4.s[3] + + mla v24.4s, v0.4s, v5.s[0] + mla v25.4s, v1.4s, v5.s[0] + mla v26.4s, v0.4s, v5.s[1] + mla v27.4s, v1.4s, v5.s[1] + + mla v28.4s, v0.4s, v5.s[2] + mla v29.4s, v1.4s, v5.s[2] + mla v30.4s, v0.4s, v5.s[3] + mla v31.4s, v1.4s, v5.s[3] + + subs x3, x3, #1 + bne .packed_packed_loop_1 + + b .non_linear_loop + +.packed_packed_loop_1_i8i8: + // PackedI8K4 (K=4-inner, r=8): per 4-K block, A is m0-3 (v0) / m4-7 (v1), + // B is n0-3 (v4) / n4-7 (v5), each lane a 4xi8 group. SDOT by-element dots + // a B column's 4 K against all 4 m rows of an A half. Same v16..v31 tile + // layout as the SMLAL kernel: v[16 + n*2 + m_half] = C[m_half*4..][n]. 
+ ld1 { v0.16b, v1.16b }, [ x1 ], #32 + ld1 { v4.16b, v5.16b }, [ x2 ], #32 + + sdot v16.4s, v0.16b, v4.4b[0] + sdot v17.4s, v1.16b, v4.4b[0] + sdot v18.4s, v0.16b, v4.4b[1] + sdot v19.4s, v1.16b, v4.4b[1] + sdot v20.4s, v0.16b, v4.4b[2] + sdot v21.4s, v1.16b, v4.4b[2] + sdot v22.4s, v0.16b, v4.4b[3] + sdot v23.4s, v1.16b, v4.4b[3] + + sdot v24.4s, v0.16b, v5.4b[0] + sdot v25.4s, v1.16b, v5.4b[0] + sdot v26.4s, v0.16b, v5.4b[1] + sdot v27.4s, v1.16b, v5.4b[1] + sdot v28.4s, v0.16b, v5.4b[2] + sdot v29.4s, v1.16b, v5.4b[2] + sdot v30.4s, v0.16b, v5.4b[3] + sdot v31.4s, v1.16b, v5.4b[3] + + subs x3, x3, #4 + bgt .packed_packed_loop_1_i8i8 + + b .non_linear_loop + +{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_i32_per_cols.j2" %} +{% set from = 16 %}{% set to = 31 %}{% include "arm64simd_mmm_load_tile.j2" %} + +.add_unicast: + ldp x5, x6, [x0, #8] + ldp x7, x8, [x0, #24] + + cmp x8, #4 + beq non_linear_addc_i32 + + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + ld1 {v0.b}[{{lane}}], [ x4 ], x6 + {% endfor %} + sshll v0.8h, v0.8b, 0 + sshll v0.4s, v0.4h, 0 + add v{{ col * 2 + reg }}.4s, v{{ col * 2 + reg }}.4s, v0.4s + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +non_linear_addc_i32: + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + ld1 {v0.s}[{{lane}}], [ x4 ], x6 + {% endfor %} + add v{{ col * 2 + reg }}.4s, v{{ col * 2 + reg }}.4s, v0.4s + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +.add_row_col_products: + ldr x2, [x0, #8] + ldr x3, [x0, #16] + + ld1 { v0.4s, v1.4s }, [ x2 ] + ld1 { v4.4s, v5.4s }, [ x3 ] + + xtn v0.4h, v0.4s + xtn v1.4h, v1.4s + xtn v4.4h, v4.4s + xtn v5.4h, v5.4s + + 
smlal v16.4s, v0.4h, v4.h[0] + smlal v17.4s, v1.4h, v4.h[0] + smlal v18.4s, v0.4h, v4.h[1] + smlal v19.4s, v1.4h, v4.h[1] + smlal v20.4s, v0.4h, v4.h[2] + smlal v21.4s, v1.4h, v4.h[2] + smlal v22.4s, v0.4h, v4.h[3] + smlal v23.4s, v1.4h, v4.h[3] + + smlal v24.4s, v0.4h, v5.h[0] + smlal v25.4s, v1.4h, v5.h[0] + smlal v26.4s, v0.4h, v5.h[1] + smlal v27.4s, v1.4h, v5.h[1] + smlal v28.4s, v0.4h, v5.h[2] + smlal v29.4s, v1.4h, v5.h[2] + smlal v30.4s, v0.4h, v5.h[3] + smlal v31.4s, v1.4h, v5.h[3] + + b .non_linear_loop + + {% include "arm64simd_mmm_i32_scale_q16_q31.j2" %} + +.store: + ldp x5, x6, [x0, #8] // c base ptr, rsc + ldp x7, x8, [x0, #24] // csc, item_size + + cmp x8, #4 + beq .store_strides_i32 + + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + st1 { v{{ col * 2 + reg }}.b }[{{ lane * 4 }}], [ x4 ], x6 + {% endfor %} + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +.store_strides_i32: + {% for col in range(8, 16) %} + mov x4, x5 + {% for reg in range(0, 2) %} + {% for lane in range(0, 4) %} + st1 { v{{ col * 2 + reg }}.s }[{{lane}}], [ x4 ], x6 + {% endfor %} + {% endfor %} + add x5, x5, x7 + {% endfor %} + + b .non_linear_loop + +.return: + ldp d14, d15, [sp], #16 + ldp d12, d13, [sp], #16 + ldp d10, d11, [sp], #16 + ldp d8, d9, [sp], #16 + + ldp x26, x27, [sp], #16 + ldp x24, x25, [sp], #16 + ldp x22, x23, [sp], #16 + ldp x20, x21, [sp], #16 + + ret + diff --git a/linalg/arm64/sme/sme_qmmm_i32_32x32.S.j2 b/linalg/arm64/sme/sme_qmmm_i32_32x32.S.j2 new file mode 100644 index 0000000000..7c7e2f5b75 --- /dev/null +++ b/linalg/arm64/sme/sme_qmmm_i32_32x32.S.j2 @@ -0,0 +1,681 @@ +// vim: ft=arm +// +// SME2 i32 32x32 quantized matmul kernel. 
+// +// ZA tile layout (4 .S tiles, 16x16 i32 each): +// ZA0.S : C[0..16, 0..16] (top-left) +// ZA1.S : C[0..16, 16..32] (top-right) +// ZA2.S : C[16..32, 0..16] (bottom-left) +// ZA3.S : C[16..32, 16..32] (bottom-right) +// +// Inner K-step (K decrements by 4 per iter, since SMOPA at i8 reduces 4): +// ld1b {z0, z1}, pn8/z, [A] ; 32 M × 4 K = 128 i8 of A +// ld1b {z2, z3}, pn8/z, [B] ; 32 N × 4 K = 128 i8 of B +// smopa za0.s, p0/m, p0/m, z0.b, z2.b ; ZA0 += A[0..16] × B[0..16] +// smopa za1.s, p0/m, p0/m, z0.b, z3.b +// smopa za2.s, p0/m, p0/m, z1.b, z2.b +// smopa za3.s, p0/m, p0/m, z1.b, z3.b +// +// SMOPA at i8 throughput: 4-way K reduction per insn × 16x16 cells = 1024 +// MACs per insn. With 4-tile rotation we approach 4 SMOPAs/cycle × 1024 +// MACs per insn = 4096 MACs/cycle ≈ 16 TOPS theoretical peak. +// +// Calling convention (extern "C", AAPCS64): +// x0 = const *FusedKerSpec, advanced 40 B per dispatcher iteration +// x1 = 4 KiB scratch buffer for tile spills, carved from the stack in the prologue (not a caller argument; used by store-generic / q_scale) +// +// Tract packing requirement: i8 inputs packed K=4-inner with K padded to a multiple of 4 (SMOPA +// reduces 4 K per lane). NOTE(review): the Rust side of this patch names this layout PackedI8K4, not PackedFormat::with_k_alignment(4) - confirm. + +.arch armv9-a+sme2 +.text +.align 4 + +.global {{G}}sme_qmmm_i32_32x32_{{suffix}} +{{G}}sme_qmmm_i32_32x32_{{suffix}}: + + stp q8, q9, [sp, #-128]! + stp q10, q11, [sp, #32] + stp q12, q13, [sp, #64] + stp q14, q15, [sp, #96] + + sub sp, sp, #4096 + mov x1, sp + + smstart + ptrue p0.b + ptrue pn8.b + mov w8, #0 + +{% include "dispatcher.j2" %} + +// -------- AddMatMul: ZA += A·B at i8 with K=4 reduction per SMOPA ---------- + +.add_mat_mul: + ldr x9, [x0, #32] // packing index + ldr x2, [x0, #24] // b ptr + ldp x3, x4, [x0, #8] // k, a ptr + cmp x3, #0 + b.eq .non_linear_loop + cmp x9, #1 + b.eq .Lmatmul_loop +// i32i32 fallback (packing != 1, auto-test path): ZA += A[:,k] (x) B[k,:], one +// K-step at a time via predicated MLA rank-1 updates. 
One instruction per line: +// the Apple/LLVM AArch64 assembler treats `;` as a COMMENT, so semicolon-packed +// statements silently drop everything after the first `;`. +.Lk32: + ld1w {z2.s}, p0/z, [x2] // B[k, 0..16] + ld1w {z3.s}, p0/z, [x2, #1, mul vl] // B[k, 16..32] + mov w12, #0 +.Lkt: + ldr w10, [x4, w12, uxtw #2] // A[k, w12] + dup z4.s, w10 + mov z16.s, p0/m, za0h.s[w12, 0] + mov z17.s, p0/m, za1h.s[w12, 0] + mla z16.s, p0/m, z2.s, z4.s // C[w12, 0..16] += A[w12] * B[0..16] + mla z17.s, p0/m, z3.s, z4.s // C[w12, 16..32] += A[w12] * B[16..32] + mov za0h.s[w12, 0], p0/m, z16.s + mov za1h.s[w12, 0], p0/m, z17.s + add w10, w12, #16 + ldr w10, [x4, w10, uxtw #2] // A[k, w12+16] + dup z4.s, w10 + mov z18.s, p0/m, za2h.s[w12, 0] + mov z19.s, p0/m, za3h.s[w12, 0] + mla z18.s, p0/m, z2.s, z4.s // C[w12+16, 0..16] += A[w12+16] * B[0..16] + mla z19.s, p0/m, z3.s, z4.s // C[w12+16, 16..32] += A[w12+16] * B[16..32] + mov za2h.s[w12, 0], p0/m, z18.s + mov za3h.s[w12, 0], p0/m, z19.s + add w12, w12, #1 + cmp w12, #16 + b.lt .Lkt + add x4, x4, #128 + add x2, x2, #128 + subs x3, x3, #1 + b.ne .Lk32 + b .non_linear_loop + +.Lmatmul_loop: + ld1b {z0.b, z1.b}, pn8/z, [x4] + ld1b {z2.b, z3.b}, pn8/z, [x2] + add x4, x4, #128 + add x2, x2, #128 + smopa za0.s, p0/m, p0/m, z0.b, z2.b + smopa za1.s, p0/m, p0/m, z0.b, z3.b + smopa za2.s, p0/m, p0/m, z1.b, z2.b + smopa za3.s, p0/m, p0/m, z1.b, z3.b + subs x3, x3, #4 + b.gt .Lmatmul_loop + b .non_linear_loop + +.clear: + zero {za} + b .non_linear_loop + +// -------- Store: i32 tile -> memory (port of Phase 1 f32 store) ----------- + +.store: + ldp x5, x6, [x0, #8] // ptr, row_byte_stride + ldp x7, x9, [x0, #24] // col_byte_stride, item_size + + cmp x7, #4 + b.ne .Lstore_generic + cmp x9, #4 + b.ne .Lstore_generic + + add x4, x5, #64 + mov w12, #0 +.Lstore_top: + st1w {za0h.s[w12, 0]}, p0, [x5] + st1w {za1h.s[w12, 0]}, p0, [x4] + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_top + mov w12, #0 
+.Lstore_bot: + st1w {za2h.s[w12, 0]}, p0, [x5] + st1w {za3h.s[w12, 0]}, p0, [x4] + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_bot + b .non_linear_loop + +.Lstore_generic: + mov x13, x9 // preserve item_size before x9 is reused as a ptr + mov x4, x1 + add x9, x1, #64 + mov w12, #0 +.Lstore_spill_top: + st1w {za0h.s[w12, 0]}, p0, [x4] + st1w {za1h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_spill_top + mov w12, #0 +.Lstore_spill_bot: + st1w {za2h.s[w12, 0]}, p0, [x4] + st1w {za3h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lstore_spill_bot + + mov x3, #0 +.Lstore_row: + mov x4, x5 + mov x10, #0 + lsl x9, x3, #7 + add x11, x1, x9 +.Lstore_col: + ldr w9, [x11], #4 + cmp x13, #1 // item_size: 1 -> strb, 2 -> strh, else (4) -> str + b.eq .Lstore_b1 + cmp x13, #2 + b.eq .Lstore_b2 + str w9, [x4] + b .Lstore_cnext +.Lstore_b1: + strb w9, [x4] + b .Lstore_cnext +.Lstore_b2: + strh w9, [x4] +.Lstore_cnext: + add x4, x4, x7 + add x10, x10, #1 + cmp x10, #32 + b.lt .Lstore_col + add x5, x5, x6 + add x3, x3, #1 + cmp x3, #32 + b.lt .Lstore_row + b .non_linear_loop + +// -------- LoadTile: ZA := row-major i32 tile from memory ------------------- + +.load_tile: + ldr x2, [x0, #16] + add x4, x2, #64 + mov w12, #0 +.Lloadtile_top: + ld1w {z6.s}, p0/z, [x2] + ld1w {z7.s}, p0/z, [x4] + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add x2, x2, #128 + add x4, x4, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lloadtile_top + mov w12, #0 +.Lloadtile_bot: + ld1w {z6.s}, p0/z, [x2] + ld1w {z7.s}, p0/z, [x4] + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x2, x2, #128 + add x4, x4, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lloadtile_bot + b .non_linear_loop + +// -------- AddUnicast: ZA += C (strided load + add) ------------------------ + +.add_unicast: + ldp x5, x6, [x0, #8] // ptr, 
row_byte_stride + ldp x7, x9, [x0, #24] // col_byte_stride, item_size + + cmp x7, #4 + b.ne .Laddu_generic + cmp x9, #4 + b.ne .Laddu_generic + + add x4, x5, #64 + mov w12, #0 +.Laddu_top: + ld1w {z8.s}, p0/z, [x5] + ld1w {z9.s}, p0/z, [x4] + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z9.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_top + mov w12, #0 +.Laddu_bot: + ld1w {z8.s}, p0/z, [x5] + ld1w {z9.s}, p0/z, [x4] + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z9.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x5, x5, x6 + add x4, x4, x6 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_bot + b .non_linear_loop + +.Laddu_generic: + // Strided gather to scratch, then contig accumulate (mirrors Phase 1). + mov x3, #0 + mov x10, x1 +.Laddu_gather_row: + mov x11, x5 + mov x4, #0 +.Laddu_gather_col: + ldr w9, [x11] + str w9, [x10], #4 + add x11, x11, x7 + add x4, x4, #1 + cmp x4, #32 + b.lt .Laddu_gather_col + add x5, x5, x6 + add x3, x3, #1 + cmp x3, #32 + b.lt .Laddu_gather_row + + mov x4, x1 + add x9, x1, #64 + mov w12, #0 +.Laddu_apply_top: + ld1w {z8.s}, p0/z, [x4] + ld1w {z10.s}, p0/z, [x9] + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z10.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Laddu_apply_top + mov w12, #0 +.Laddu_apply_bot: + ld1w {z8.s}, p0/z, [x4] + ld1w {z10.s}, p0/z, [x9] + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + add z6.s, p0/m, z6.s, z8.s + add z7.s, p0/m, z7.s, z10.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, 
w12, #1 + cmp w12, #16 + b.lt .Laddu_apply_bot + b .non_linear_loop + +// -------- AddRowColProducts: ZA += rows ⊗ cols (i32 outer product) -------- +// +// rows: 32 i32 (broadcast per M-row), cols: 32 i32 (lane vector per N-col). +// Per ZA row, we need: ZA[i, j] += rows[i] * cols[j]. Slice-by-slice. + +.add_row_col_products: + ldp x2, x3, [x0, #8] // rows ptr, cols ptr + ld1w {z4.s}, p0/z, [x3] // cols[0..16] + ld1w {z5.s}, p0/z, [x3, #1, mul vl] // cols[16..32] + + // Top 16 rows + mov w12, #0 +.Larcp_top: + ldr w9, [x2], #4 + dup z16.s, w9 // broadcast rows[i] to z16 + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + mla z6.s, p0/m, z16.s, z4.s // z6 += z16 * cols[0..16] + mla z7.s, p0/m, z16.s, z5.s // z7 += z16 * cols[16..32] + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .Larcp_top + // Bottom 16 rows + mov w12, #0 +.Larcp_bot: + ldr w9, [x2], #4 + dup z16.s, w9 + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + mla z6.s, p0/m, z16.s, z4.s + mla z7.s, p0/m, z16.s, z5.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .Larcp_bot + b .non_linear_loop + +// -------- scalar fuse ops: broadcast scalar, apply lane-wise -------------- +// +// Sub vs SubF (matches Phase 1's f32 convention): +// ScalarSub → result = scalar - z (mnemonic: subr) +// ScalarSubF → result = z - scalar (mnemonic: sub) + +{% macro scalar_op_i32(label, op) %} +{{label}}: + ldr w2, [x0, #8] + dup z4.s, w2 + mov w12, #0 +.L{{label|replace('.', '')}}_top: + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_top + mov w12, #0 +.L{{label|replace('.', '')}}_bot: + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + {{op}} 
z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_bot + b .non_linear_loop +{% endmacro %} + +{{ scalar_op_i32('.scalar_add', 'add') }} +{{ scalar_op_i32('.scalar_mul', 'mul') }} +{{ scalar_op_i32('.scalar_sub', 'subr') }} +{{ scalar_op_i32('.scalar_sub_flipped', 'sub') }} +{{ scalar_op_i32('.scalar_min', 'smin') }} +{{ scalar_op_i32('.scalar_max', 'smax') }} + +// -------- per_col fuse ops: 32-elem vector, broadcast across M rows ------ + +{% macro per_col_op_i32(label, op) %} +{{label}}: + ldr x2, [x0, #8] + ld1w {z4.s}, p0/z, [x2] + ld1w {z5.s}, p0/z, [x2, #1, mul vl] + mov w12, #0 +.L{{label|replace('.', '')}}_top: + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z5.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_top + mov w12, #0 +.L{{label|replace('.', '')}}_bot: + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z5.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_bot + b .non_linear_loop +{% endmacro %} + +{{ per_col_op_i32('.per_col_add', 'add') }} +{{ per_col_op_i32('.per_col_mul', 'mul') }} +{{ per_col_op_i32('.per_col_sub', 'subr') }} +{{ per_col_op_i32('.per_col_sub_flipped', 'sub') }} +{{ per_col_op_i32('.per_col_min', 'smin') }} +{{ per_col_op_i32('.per_col_max', 'smax') }} + +// -------- per_row fuse ops: 32-elem vector, one scalar per M row --------- + +{% macro per_row_op_i32(label, op) %} +{{label}}: + ldr x2, [x0, #8] + add x3, x2, #64 + mov w12, #0 +.L{{label|replace('.', '')}}_top: + ldr w4, [x2], #4 + dup z4.s, w4 + mov z6.s, p0/m, za0h.s[w12, 0] + mov z7.s, p0/m, za1h.s[w12, 0] + 
{{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_top + mov w12, #0 +.L{{label|replace('.', '')}}_bot: + ldr w4, [x3], #4 + dup z4.s, w4 + mov z6.s, p0/m, za2h.s[w12, 0] + mov z7.s, p0/m, za3h.s[w12, 0] + {{op}} z6.s, p0/m, z6.s, z4.s + {{op}} z7.s, p0/m, z7.s, z4.s + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add w12, w12, #1 + cmp w12, #16 + b.lt .L{{label|replace('.', '')}}_bot + b .non_linear_loop +{% endmacro %} + +{{ per_row_op_i32('.per_row_add', 'add') }} +{{ per_row_op_i32('.per_row_mul', 'mul') }} +{{ per_row_op_i32('.per_row_sub', 'subr') }} +{{ per_row_op_i32('.per_row_sub_flipped', 'sub') }} +{{ per_row_op_i32('.per_row_min', 'smin') }} +{{ per_row_op_i32('.per_row_max', 'smax') }} + +// -------- Quantization fuse ops (bit-exact port of generic/rounding.rs) ---- +// +// Strategy: spill the 32x32 i32 ZA tile to the 4 KiB scratch (x1), quantize +// element-wise in SCALAR GP registers (streaming-mode legal: smull/lsr/asr/ +// cneg/cset/... are base A64 and unaffected by PSTATE.SM), then reload to ZA. +// Quant is not the hot path; this mirrors the scalar approach already proven +// in arm64/sve/sve_mmm_i32.c. Everything is inlined (no `bl` — a nested call +// would clobber x30 and corrupt the final `ret`). +// +// Bit-exactness: the reference forms the FULL i64 product (mult*v) and does a +// single magnitude-rounding shift by (shift+31) with a per-policy nudge. A +// vector sqdmulh+srshl truncates the low 31 bits before the second shift, so +// it is NOT equivalent — hence the i64 scalar port. +// +// RoundingPolicy: Native=0 Zero=1 Away=2 MinusInf=3 PlusInf=4 Even=5 Odd=6. + +// Spill ZA0..ZA3 -> scratch[x1] as a contiguous 32x32 row-major i32 matrix +// (same layout the generic store path uses). Clobbers x4, x9, w12. 
+{% macro za_spill(sfx) %} + mov x4, x1 + add x9, x1, #64 + mov w12, #0 +.Lspt_{{sfx}}: + st1w {za0h.s[w12, 0]}, p0, [x4] + st1w {za1h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lspt_{{sfx}} + mov w12, #0 +.Lspb_{{sfx}}: + st1w {za2h.s[w12, 0]}, p0, [x4] + st1w {za3h.s[w12, 0]}, p0, [x9] + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lspb_{{sfx}} +{% endmacro %} + +// Reload scratch[x1] (32x32 row-major i32) -> ZA0..ZA3. Clobbers x4,x9,w12,z6,z7. +{% macro za_reload(sfx) %} + mov x4, x1 + add x9, x1, #64 + mov w12, #0 +.Lrlt_{{sfx}}: + ld1w {z6.s}, p0/z, [x4] + ld1w {z7.s}, p0/z, [x9] + mov za0h.s[w12, 0], p0/m, z6.s + mov za1h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lrlt_{{sfx}} + mov w12, #0 +.Lrlb_{{sfx}}: + ld1w {z6.s}, p0/z, [x4] + ld1w {z7.s}, p0/z, [x9] + mov za2h.s[w12, 0], p0/m, z6.s + mov za3h.s[w12, 0], p0/m, z7.s + add x4, x4, #128 + add x9, x9, #128 + add w12, w12, #1 + cmp w12, #16 + b.lt .Lrlb_{{sfx}} +{% endmacro %} + +// Magnitude-rounding shared by q_scale and q_shr (mirrors `Mul for Scaler` +// / `i32::q_shr`). In: x12 = val (i64), x5 = shift, x6 = policy. Out: w14 (i32). +// Clobbers x13,x15,x16,x17. Preserves x5,x6,x7,x10,x11,x12. 
+{% macro round_mag(sfx) %} + cmp x5, #0 + b.gt .Lrpos_{{sfx}} + neg x13, x5 + lsl x13, x12, x13 // val << (-shift) + mov w14, w13 + b .Lrend_{{sfx}} +.Lrpos_{{sfx}}: + cmp x12, #0 + cneg x15, x12, mi // x15 = |val| + sub x13, x5, #1 + mov x16, #1 + lsl x16, x16, x13 // x16 = half = 1 << (shift-1) + cmp x6, #2 // Away -> nudge 0 + b.eq .Lrn0_{{sfx}} + cmp x6, #1 // Zero -> nudge -1 + b.ne .Lrna_{{sfx}} + mov x17, #-1 + b .Lrnd_{{sfx}} +.Lrna_{{sfx}}: + cmp x6, #3 // MinusInf -> -(val >= 0) + b.ne .Lrnb_{{sfx}} + cmp x12, #0 + cset x17, ge + neg x17, x17 + b .Lrnd_{{sfx}} +.Lrnb_{{sfx}}: + cmp x6, #4 // PlusInf -> -(val <= 0) + b.ne .Lrnc_{{sfx}} + cmp x12, #0 + cset x17, le + neg x17, x17 + b .Lrnd_{{sfx}} +.Lrnc_{{sfx}}: + cmp x6, #5 // Even -> ((|val|>>shift)&1) - 1 + b.ne .Lrno_{{sfx}} + lsr x17, x15, x5 + and x17, x17, #1 + sub x17, x17, #1 + b .Lrnd_{{sfx}} +.Lrno_{{sfx}}: // Odd -> -((|val|>>shift)&1) + lsr x17, x15, x5 + and x17, x17, #1 + neg x17, x17 + b .Lrnd_{{sfx}} +.Lrn0_{{sfx}}: + mov x17, #0 +.Lrnd_{{sfx}}: + add x15, x15, x16 + add x15, x15, x17 + lsr x15, x15, x5 // (|val| + half + nudge) >> shift + cmp x12, #0 + cneg x14, x15, mi // signum(val) * mag +.Lrend_{{sfx}}: +{% endmacro %} + +// QScale(shift, policy, mult): val = mult*v (i64); shift += 31; magnitude round. +.q_scale: + ldr x5, [x0, #8] // shift (isize) + ldr x6, [x0, #16] // policy + ldr w7, [x0, #24] // mult (i32) + add x5, x5, #31 + {{ za_spill('qsc') }} + mov x10, x1 + mov x11, #1024 +.Lqsc_loop: + ldr w9, [x10] + smull x12, w7, w9 // val = (i64)mult * (i64)v + {{ round_mag('qsc') }} + str w14, [x10], #4 + subs x11, x11, #1 + b.ne .Lqsc_loop + {{ za_reload('qsc') }} + b .non_linear_loop + +// RoundingShiftRight(shift, policy): val = v (i64); magnitude round (shift>0). 
+.q_shr: + ldr x5, [x0, #8] // shift (usize, >= 1) + ldr x6, [x0, #16] // policy + {{ za_spill('qsr') }} + mov x10, x1 + mov x11, #1024 +.Lqsr_loop: + ldr w9, [x10] + sxtw x12, w9 // val = (i64)v + {{ round_mag('qsr') }} + str w14, [x10], #4 + subs x11, x11, #1 + b.ne .Lqsr_loop + {{ za_reload('qsr') }} + b .non_linear_loop + +// ShiftLeft(shift): result = v << shift (32-bit wrapping, matches i32::q_shl). +.q_shl: + ldr x5, [x0, #8] // shift (usize) + {{ za_spill('qsl') }} + mov x10, x1 + mov x11, #1024 +.Lqsl_loop: + ldr w9, [x10] + lsl w9, w9, w5 + str w9, [x10], #4 + subs x11, x11, #1 + b.ne .Lqsl_loop + {{ za_reload('qsl') }} + b .non_linear_loop + +// -------- LeakyRelu (excluded via CAN_FUSE_I32) --------------------------- + +.leaky_relu: + b .unsupported + +// -------- epilogue -------------------------------------------------------- + +.return: + smstop + add sp, sp, #4096 + ldp q14, q15, [sp, #96] + ldp q12, q13, [sp, #64] + ldp q10, q11, [sp, #32] + ldp q8, q9, [sp], #128 + ret diff --git a/linalg/benches/qmmm_i8.rs b/linalg/benches/qmmm_i8.rs new file mode 100644 index 0000000000..38c234b771 --- /dev/null +++ b/linalg/benches/qmmm_i8.rs @@ -0,0 +1,56 @@ +// int8 -> i32 GEMM (qmmm_i32) microbench. A/B the SME SMOPA kernel vs the NEON +// fallback by running twice: default (SME) vs TRACT_SME_DISABLE=1 (arm64simd 8x8). +extern crate criterion; +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec}; + +use DatumType::I32; + +fn qmmm(be: &mut criterion::Bencher, &(m, k, n): &(usize, usize, usize)) { + unsafe { + let mmm = tract_linalg::ops().mmm(I32, Some(m), Some(k), Some(n)).unwrap(); + // packing index 1 == i8i8 for both sme_qmmm_i32_32x32 and arm64simd_mmm_i32_8x8. 
+ let a = Tensor::zero::(&[m, k]).unwrap(); + let b = Tensor::zero::(&[k, n]).unwrap(); + let packing = &mmm.packings()[1]; + let pa = packing.0.prepare_one(&a, 1, 0).unwrap(); + let pb = packing.1.prepare_one(&b, 0, 1).unwrap(); + let mut c = Tensor::zero::(&[m, n]).unwrap(); + be.iter(move || { + mmm.run( + m, + n, + &[ + FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }, + FusedSpec::Store(mmm.c_view(Some(0), Some(1)).wrap(&c.view_mut())), + ], + ) + }); + } +} + +fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("qmmm_i8"); + g.sample_size(20); + for &shape in &[ + (256usize, 256usize, 256usize), + (512, 512, 512), + (1024, 1024, 1024), + (128, 768, 768), + (384, 768, 768), + (64, 2048, 2048), + ] { + let (m, k, n) = shape; + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_function(format!("{m}x{k}x{n}"), |be| qmmm(be, &shape)); + } + g.finish(); +} + +criterion::criterion_group!(benches, bench); +criterion::criterion_main!(benches); diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index 47f7cd4ccf..3ee32dd46f 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -189,6 +189,40 @@ pub fn has_fp16() -> bool { || *HAS_FP16 } +// FEAT_DotProd (SDOT/UDOT), ARMv8.2. TRACT_DOTPROD_DISABLE=1 forces it off so +// callers can A/B the SDOT kernel against the SMLAL 8x8 fallback on one binary. +#[cfg(target_os = "macos")] +pub fn has_dotprod() -> bool { + // Every Apple arm64 CPU (M1+/A11+) implements FEAT_DotProd. + std::env::var_os("TRACT_DOTPROD_DISABLE").is_none() +} + +#[cfg(target_os = "linux")] +pub fn has_dotprod() -> bool { + if std::env::var_os("TRACT_DOTPROD_DISABLE").is_some() { + return false; + } + // HWCAP_ASIMDDP = 1 << 20 on aarch64. 
+ const HWCAP_ASIMDDP: u64 = 1 << 20; + const AT_HWCAP: u64 = 16; + unsafe extern "C" { + fn getauxval(t: u64) -> u64; + } + unsafe { (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0 } +} + +#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "ios")))] +pub fn has_dotprod() -> bool { + false +} + +#[cfg(target_os = "ios")] +pub fn has_dotprod() -> bool { + // A11+ (iPhone10,1+) implement FEAT_DotProd. + std::env::var_os("TRACT_DOTPROD_DISABLE").is_none() + && IPHONE_MODEL_MAJOR.map(|it| it >= 10).unwrap_or(false) +} + #[target_feature(enable = "fp16")] #[inline] pub unsafe fn add_f16(a: f16, b: f16) -> f16 { @@ -351,7 +385,12 @@ pub fn plug(ops: &mut Ops) { arm64fp16::plug(ops); } - ops.qmmm_i32 = Box::new(|_, _, _| arm64simd_mmm_i32_8x8.mmm()); + // SDOT (~4x the SMLAL 8x8) when FEAT_DotProd is present, else the SMLAL 8x8 fallback. + if has_dotprod() { + ops.qmmm_i32 = Box::new(|_, _, _| arm64simd_mmm_i32_8x8_dot.mmm()); + } else { + ops.qmmm_i32 = Box::new(|_, _, _| arm64simd_mmm_i32_8x8.mmm()); + } ops.qmmv_i32 = Box::new(|_, _| arm64simd_mmm_i32_64x1.mmm()); ops.mmv_f32 = match *KIND { Kind::CortexA53 => Box::new(|_, _| arm64simd_mmm_f32_64x1_a53.mmm()), diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index f90d92926f..7d2a20d108 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -84,6 +84,16 @@ MMMExternKernel!(arm64simd_mmm_i32_8x8(8, 8)@(16, 16) store(i8) ); +// SDOT (FEAT_DotProd) variant: 4-K reduction per instruction (~4x the SMLAL +// 8x8 above). Uses the K=4-inner PackedI8K4 packing; identical v16..v31 tile +// layout, so it reuses all the i32 fuse/store/q_scale machinery. 
+MMMExternKernel!(arm64simd_mmm_i32_8x8_dot(8, 8)@(16, 16) + where(super::has_dotprod) + packing[1] = i8i8 => |k| k.with_packing(crate::pack::PackedI8K4::new(8), crate::pack::PackedI8K4::new(8)); + quality(ManuallyOptimized) + store(i8) +); + MMMExternKernel!(arm64simd_mmm_i32_64x1(64, 1)@(16, 1) packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 64,16), PackedFormat::new(DatumType::I8, 1, 1)); quality(ManuallyOptimized) @@ -110,6 +120,7 @@ pub fn plug(ops: &mut Ops) { arm64simd_mmm_f32_64x1_a53.mmm(), arm64simd_mmm_f32_64x1_a55.mmm(), arm64simd_mmm_i32_8x8.mmm(), + arm64simd_mmm_i32_8x8_dot.mmm(), arm64simd_mmm_i32_64x1.mmm(), ]); panel_extract::plug(ops); diff --git a/linalg/src/arm64/sme.rs b/linalg/src/arm64/sme.rs index 5b8b194d6c..a6adb75d68 100644 --- a/linalg/src/arm64/sme.rs +++ b/linalg/src/arm64/sme.rs @@ -18,6 +18,13 @@ const CAN_FUSE: fn(&FusedSpec) -> bool = |f| { const SME: fn() -> bool = has_sme; const SME2: fn() -> bool = has_sme2; +// The SMOPA i32 kernel implements the quant fuse ops (QScale / RoundingShiftRight +// / ShiftLeft) bit-exactly; only LeakyRelu is unsupported (kernel returns 1). +const CAN_FUSE_I32: fn(&FusedSpec) -> bool = |f| !matches!(f, FusedSpec::LeakyRelu(_)); + +MMMExternKernel!(sme_qmmm_i32_32x32(32,32)@(128,128) where(SME2) can_fuse(CAN_FUSE_I32) + packing[1] = i8i8 => |k| k.with_packing(crate::pack::PackedI8K4::new(32), crate::pack::PackedI8K4::new(32)); + quality(ManuallyOptimized) store(i8)); // Streaming vector length in bytes, read via `RDSVL x0, #1` (encoding // 0x04bf5820). 
RDSVL is legal in non-streaming mode, but is UNDEFINED @@ -188,7 +195,8 @@ pub fn plug(ops: &mut Ops) { if has_sme2() { log::info!("SME2 GEMV optimisation activated"); ops.mmv_f32 = Box::new(|_, _| sme_mmv_f32_64x1.mmm()); - ops.mmm_impls.extend_from_slice(&[sme_mmv_f32_64x1.mmm()]); + ops.qmmm_i32 = Box::new(|_, _, _| sme_qmmm_i32_32x32.mmm()); + ops.mmm_impls.extend_from_slice(&[sme_mmv_f32_64x1.mmm(), sme_qmmm_i32_32x32.mmm()]); } if !has_sme() && !has_sme2() { log::info!("No SME optimisation"); diff --git a/linalg/src/frame/mmm/mod.rs b/linalg/src/frame/mmm/mod.rs index bae134475c..bbe84cf8b9 100644 --- a/linalg/src/frame/mmm/mod.rs +++ b/linalg/src/frame/mmm/mod.rs @@ -76,6 +76,10 @@ pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any { fn quality(&self) -> ImplementationQuality; fn dynamic_boost(&self) -> isize; + /// Whether this kernel is runnable on the current CPU (platform feature + /// gate, e.g. FEAT_DotProd for the SDOT i8 kernel). + fn is_supported_here(&self) -> bool; + #[allow(clippy::type_complexity)] fn packings(&self) -> &[(Box, Box)]; @@ -145,6 +149,10 @@ impl MatMatMul for K { MatMatMulKer::dynamic_boost(self) } + fn is_supported_here(&self) -> bool { + MatMatMulKer::is_supported_here(self) + } + fn packings(&self) -> &[(Box, Box)] { self.packings() } diff --git a/linalg/src/frame/mmm/tests/packed_packed.rs b/linalg/src/frame/mmm/tests/packed_packed.rs index 63248d2ed2..8437990c72 100644 --- a/linalg/src/frame/mmm/tests/packed_packed.rs +++ b/linalg/src/frame/mmm/tests/packed_packed.rs @@ -1,7 +1,7 @@ +use crate::WeightType; use crate::block_quant::PackedBlockQuantFormat; use crate::mmm::tests::display_error; use crate::mmm::{AsInputValue, FusedKerSpec, FusedSpec, MatMatMul, MatMatMulKer, OutputStoreKer}; -use crate::pack::PackedFormat; use proptest::collection::vec; use proptest::prelude::*; use std::fmt::Debug; @@ -255,9 +255,8 @@ impl PackedPackedProblem { pub fn padded_inputs(&self) -> TractResult<(Tensor, 
Tensor)> { let (pack_a, pack_b) = &self.ker.packings()[self.packing]; - assert!(pack_b.k_alignment() == 1); let (m, k, n) = self.mkn(); - let k_aligned = k.next_multiple_of(pack_a.k_alignment()); + let k_aligned = k.next_multiple_of(pack_a.k_alignment().max(pack_b.k_alignment())); let mut a = Tensor::zero::(&[m, k_aligned])?; for row in 0..m { @@ -265,8 +264,8 @@ impl PackedPackedProblem { a.try_as_plain_mut()?.to_array_view_mut()?[[row, col]] = self.a[col + k * row]; } } - if let Some(pf) = pack_a.downcast_ref::() { - a = a.cast_to_dt(pf.dt)?.into_owned(); + if let WeightType::Plain(dt) = pack_a.precursor() { + a = a.cast_to_dt(dt)?.into_owned(); } let mut b = Tensor::zero::(&[k_aligned, n])?; for row in 0..k { @@ -274,8 +273,8 @@ impl PackedPackedProblem { b.try_as_plain_mut()?.to_array_view_mut()?[[row, col]] = self.b[col + n * row]; } } - if let Some(pf) = pack_b.downcast_ref::() { - b = b.cast_to_dt(pf.dt)?.into_owned(); + if let WeightType::Plain(dt) = pack_b.precursor() { + b = b.cast_to_dt(dt)?.into_owned(); } Ok((a, b)) @@ -283,9 +282,9 @@ impl PackedPackedProblem { pub fn reference(&self) -> TractResult { let (m, k, n) = self.mkn(); - let pack_a = &self.ker.packings()[self.packing].0; + let (pack_a, pack_b) = &self.ker.packings()[self.packing]; let (mut a, b) = self.padded_inputs()?; - let k_aligned = k.next_multiple_of(pack_a.k_alignment()); + let k_aligned = k.next_multiple_of(pack_a.k_alignment().max(pack_b.k_alignment())); if let Some(pbqf) = pack_a.downcast_ref::() { a = pbqf.simulate_precision_loss(a, 1)?; }; @@ -312,8 +311,7 @@ impl PackedPackedProblem { pub fn run(&self) -> TractResult { let (m, k, n) = self.mkn(); let (pack_a, pack_b) = &self.ker.packings()[self.packing]; - assert!(pack_b.k_alignment() == 1); - let k_aligned = k.next_multiple_of(pack_a.k_alignment()); + let k_aligned = k.next_multiple_of(pack_a.k_alignment().max(pack_b.k_alignment())); let (a, b) = self.padded_inputs()?; let pa = pack_a.prepare_one(&a, 1, 0)?; diff --git 
a/linalg/src/frame/pack.rs b/linalg/src/frame/pack.rs index ed9d10ed25..a9a5699583 100644 --- a/linalg/src/frame/pack.rs +++ b/linalg/src/frame/pack.rs @@ -662,6 +662,111 @@ where } } +// K=4-inner packing writer (SMOPA / PackedI8K4 layout), fed in K-OUTER order +// (same feed as KOutWriter, used by the im2col patchers): for each k, all mn. +// Within a panel, element (k, local_mn) lands at (k/4)*r*4 + local_mn*4 + (k%4), +// so consecutive mn for a fixed k are stride-4 stores. +#[derive(Debug)] +pub struct KOut4Writer<'p, T> +where + T: Copy + std::fmt::Debug, +{ + base: *mut T, + r4: usize, // r * 4 + panel_len: usize, // k_aligned * r + panels: usize, + panel_width: usize, + last_panel_width: usize, + kb: usize, // k / 4 + kr: usize, // k % 4 + panel: usize, + local_mn: usize, + _phantom: PhantomData<&'p T>, +} + +impl<'p, T> KOut4Writer<'p, T> +where + T: Copy + std::fmt::Debug, +{ + pub fn new(base: *mut T, r: usize, panel_len: usize, mn: usize) -> KOut4Writer<'p, T> { + let panels = mn.divceil(r).max(1); + let last_panel_width = mn - (panels - 1) * r; + KOut4Writer { + base, + r4: r * 4, + panel_len, + panels, + panel_width: r, + last_panel_width, + kb: 0, + kr: 0, + panel: 0, + local_mn: 0, + _phantom: PhantomData, + } + } + #[inline(always)] + fn panel_width(&self) -> usize { + if self.panel == self.panels - 1 { self.last_panel_width } else { self.panel_width } + } + #[inline(always)] + fn advance(&mut self, by: usize) { + self.local_mn += by; + if self.local_mn >= self.panel_width() { + self.local_mn = 0; + self.panel += 1; + if self.panel == self.panels { + self.panel = 0; + self.kr += 1; + if self.kr == 4 { + self.kr = 0; + self.kb += 1; + } + } + } + } +} + +impl PackingWriter for KOut4Writer<'_, T> +where + T: Copy + std::fmt::Debug, +{ + #[inline(always)] + fn write(&mut self, t: T) { + unsafe { + let off = self.panel * self.panel_len + self.kb * self.r4 + self.local_mn * 4 + self.kr; + *self.base.add(off) = t; + } + self.advance(1); + } + + #[inline] 
+ fn write_slice(&mut self, ts: &[T]) { + let n = ts.len(); + if n == 0 { + return; + } + let pw = self.panel_width(); + if self.local_mn + n <= pw { + // Whole slice stays inside the current (panel, k): tight stride-4 store. + unsafe { + let mut d = self.base.add( + self.panel * self.panel_len + self.kb * self.r4 + self.local_mn * 4 + self.kr, + ); + for &t in ts { + *d = t; + d = d.add(4); + } + } + self.advance(n); + } else { + for &t in ts { + self.write(t); + } + } + } +} + #[inline(never)] unsafe fn pack_mn_major( b: *const u8, @@ -1091,3 +1196,128 @@ mod test { .check(); } } + +// K=4-inner packing for SMOPA-style int8 matmul: 4 contiguous K per mn-lane. +// Layout: out[(k/4)*r*4 + m*4 + (k%4)] = src[m,k]. k_alignment=4. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedI8K4 { + pub r: usize, + pub align: usize, +} +impl PackedI8K4 { + pub fn new(r: usize) -> Self { + PackedI8K4 { r, align: 16 } + } + fn panel(&self, k: usize) -> usize { + (k.div_ceil(4) * 4) * self.r + } + // Buffer geometry for direct (one-pass) packing via KOut4Writer. + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.divceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + pub fn write_with_k_outer<'p, T: Copy + std::fmt::Debug>( + &self, + pb: *mut T, + k: usize, + mn: usize, + ) -> KOut4Writer<'p, T> { + KOut4Writer::new(pb, self.r, self.panel(k), mn) + } + // K=4-inner pack from a (possibly strided) view: out[(k/4)*r*4 + m*4 + (k%4)] = src[m,k]. 
+ pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = k.div_ceil(4) * 4; + let pl = kp * self.r; + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + // Cache-friendly K=4-inner pack: outer (panel, k-block, k%4), inner over the + // r mn-lanes. Per (kb, kr) the source K-row is read sequentially (stride mn) + // and the destination column is written at stride 4. No per-element div/mod. + let (ks, ms) = (st[k_axis], st[mn_axis]); + let kblocks = kp / 4; + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut i8; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * pl); + let mn0 = (p * self.r) as isize; + for kb in 0..kblocks { + for kr in 0..4 { + let kk = kb * 4 + kr; + if kk >= k { + break; + } + let srow = src.offset(kk as isize * ks + mn0 * ms); + let dcol = panel.add(kb * self.r * 4 + kr); + for lm in 0..pw { + *dcol.add(lm * 4) = *srow.offset(lm as isize * ms); + } + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} +impl std::fmt::Display for PackedI8K4 { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "I8K4[{}]", self.r) + } +} +impl MMMInputFormat for PackedI8K4 { + fn prepare_tensor(&self, t: &Tensor, k: usize, mn: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k, mn)?).into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn precursor(&self) -> WeightType { + WeightType::Plain(i8::datum_type()) + } + fn r(&self) -> usize { + self.r + 
} + fn k_alignment(&self) -> usize { + 4 + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +}