Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions core/src/ops/cnn/conv/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::mmm::{MMMInputFormat, MatMatMul};
use tract_linalg::pack::PackedFormat;
use tract_linalg::pack::{PackedFormat, PackedI8K4};

#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
pub struct Conv {
Expand Down Expand Up @@ -123,13 +123,11 @@ impl Conv {
&[kernel],
)?
} else {
let format = format
.downcast_ref::<PackedFormat>()
.context("Expect regular packing for numeric weights")?;
// PackedFormat or a custom numeric packer (e.g. PackedI8K4).
model.wire_node(
format!("{name}.prep_kernel.pack"),
OptMatMulPack {
packers: vec![format.clone()],
packers: vec![dyn_clone::clone_box(format)],
k_axis: 2,
mn_axis: 1,
mode_picker: ModePicker::Single,
Expand Down Expand Up @@ -240,7 +238,11 @@ impl Conv {
&sum_ker_n_g_c,
)?;

ensure!(mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some());
ensure!(
mmm.packings()[packing].1.downcast_ref::<PackedFormat>().is_some()
|| mmm.packings()[packing].1.downcast_ref::<PackedI8K4>().is_some(),
"Im2Col/QSumB support PackedFormat or PackedI8K4 activation packings"
);
let mut sum_x = model.wire_node(
format!("{name}.sum_x"),
super::QSumB { dt: b_fact.datum_type, n, r: mmm.nr(), k },
Expand Down
151 changes: 92 additions & 59 deletions core/src/ops/cnn/conv/im2col.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use tract_linalg::mmm::{
EagerPackedInput, MMMInputValue, MatMatMul, PackedExoticFact, PackedMatrixStorage,
EagerPackedInput, MMMInputFormat, MMMInputValue, MatMatMul, PackedExoticFact,
PackedMatrixStorage,
};
use tract_linalg::pack::{PackedFormat, PackingWriter};
use tract_linalg::pack::{PackedFormat, PackedI8K4, PackingWriter};

use crate::internal::*;
use ndarray::prelude::*;
Expand All @@ -23,7 +24,8 @@ struct SymbolicGeometry {
group: usize,
pool_spec: PoolSpec,
pool_geometry: PoolGeometry,
b_pack: PackedFormat,
// The kernel's activation packing: PackedFormat (K-major) or PackedI8K4 (K=4-inner).
out_format: Box<dyn MMMInputFormat>,
k: usize,
}

Expand All @@ -32,18 +34,18 @@ struct ConcreteGeometry {
pool: ConcretePoolGeometry,
pub n: usize,
k: usize,
pub b_pack: PackedFormat,
pub out_format: Box<dyn MMMInputFormat>,
pub ci_per_group: usize,
patcher: Patcher,
input_shape_with_n: DataShape,
packed_shape: TVec<usize>, // always Batch,Group
}

impl GeometryBound<SymbolicGeometry, ConcreteGeometry> {
pub fn b_pack(&self) -> &PackedFormat {
pub fn out_format(&self) -> &dyn MMMInputFormat {
match self {
GeometryBound::Symbolic(s) => &s.b_pack,
GeometryBound::Concrete(s) => &s.b_pack,
GeometryBound::Symbolic(s) => &*s.out_format,
GeometryBound::Concrete(s) => &*s.out_format,
}
}
pub fn k(&self) -> usize {
Expand Down Expand Up @@ -88,7 +90,7 @@ impl ResolveTo<ConcreteGeometry> for SymbolicGeometry {
n,
k: self.k,
ci_per_group,
b_pack: self.b_pack.clone(),
out_format: self.out_format.clone(),
patcher,
input_shape_with_n,
packed_shape,
Expand All @@ -105,15 +107,10 @@ impl Im2Col {
mmm: Box<dyn MatMatMul>,
packing: usize,
) -> TractResult<Im2Col> {
let b_pack = mmm.packings()[packing]
.1
.downcast_ref::<PackedFormat>()
.context("Im2Col expects regular packed format")?
.clone();

let out_format = dyn_clone::clone_box(&*mmm.packings()[packing].1);
let pool_geometry = pool_spec.compute_geo(input_full_shape)?;
let geometry: GeometryBound<_, _> =
SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, b_pack, k }
SymbolicGeometry { group, pool_spec: pool_spec.clone(), pool_geometry, out_format, k }
.into();
let geometry = geometry.optimize_if(input_full_shape.as_concrete())?;
Ok(Im2Col { pool_spec, group, geometry })
Expand Down Expand Up @@ -156,8 +153,21 @@ impl EvalOp for Im2Col {
if !self.pool_spec.data_format.has_n() {
input.insert_axis(0)?;
}
let panel_bytes =
geometry.b_pack.single_panel_len(geometry.k) * input.datum_type().size_of();
let dt = input.datum_type();
let r = geometry.out_format.r();
// Buffer geometry. zero_init is set for PackedI8K4 because its K=4-inner
// writer skips the K-padding lanes (k..k_aligned), and SMOPA accumulates those
// lanes, so they must be 0. PackedFormat has no K padding, and its mn-padding
// maps to discarded output rows, so an uninitialized buffer is fine there
// (matches prior behaviour).
let (single_panel_len, buf_align, zero_init) =
if let Some(pf) = geometry.out_format.downcast_ref::<PackedFormat>() {
(pf.single_panel_len(geometry.k), pf.alignment(), false)
} else if let Some(p4) = geometry.out_format.downcast_ref::<PackedI8K4>() {
(p4.single_panel_len(geometry.k), p4.alignment(), true)
} else {
bail!("Im2Col: unsupported packing format {:?}", geometry.out_format)
};
let panel_bytes = single_panel_len * dt.size_of();

let n_batches = *geometry.input_shape_with_n.n().unwrap_or(&1);
let n_groups = self.group;
Expand All @@ -169,12 +179,15 @@ impl EvalOp for Im2Col {
let n =
if geometry.pool.output_shape.shape.contains(&0) { 0 } else { geometry.n };
let mut data = Tensor::uninitialized_aligned_dt(
input.datum_type(),
&[geometry.b_pack.len(geometry.k, n)],
geometry.b_pack.alignment(),
dt,
&[n.divceil(r) * single_panel_len],
buf_align,
)?;
if zero_init {
data.as_bytes_mut().fill(0);
}
if n > 0 {
dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
dispatch_copy_by_size!(Patcher::patch(dt)(
&geometry.patcher,
&geometry,
&input,
Expand All @@ -185,7 +198,7 @@ impl EvalOp for Im2Col {
}
values.push(Box::new(EagerPackedInput {
fact: PackedExoticFact {
format: Box::new(geometry.b_pack.clone()),
format: geometry.out_format.clone(),
k: geometry.k,
mn: n.to_dim(),
},
Expand All @@ -211,7 +224,7 @@ impl TypedOp for Im2Col {
let output_shape = self.pool_spec.output_shape(&inputs[0].shape)?;
let mn = output_shape.hw_dims().iter().product::<TDim>();
let pof = PackedExoticFact {
format: Box::new(self.geometry.b_pack().clone()),
format: dyn_clone::clone_box(self.geometry.out_format()),
k: self.geometry.k(),
mn,
};
Expand Down Expand Up @@ -259,34 +272,57 @@ impl Patcher {
pack: &'p mut TensorView,
g: usize,
pad_value: Option<&Tensor>,
) -> TractResult<()> {
// Pick the packing writer for the kernel's output format, then run the
// (writer-generic) patcher. PackedFormat keeps the K-major fast path;
// PackedI8K4 writes the SMOPA K=4-inner layout in the same single pass.
let ptr = unsafe { pack.as_slice_mut_unchecked::<T>().as_mut_ptr() };
if let Some(pf) = geo.out_format.downcast_ref::<PackedFormat>() {
let mut w = pf.write_with_k_outer(ptr, geo.k, geo.n);
self.run::<T, _>(geo, input, g, pad_value, &mut w)
} else if let Some(p4) = geo.out_format.downcast_ref::<PackedI8K4>() {
let mut w = p4.write_with_k_outer(ptr, geo.k, geo.n);
self.run::<T, _>(geo, input, g, pad_value, &mut w)
} else {
bail!("Im2Col: unsupported packing format {:?}", geo.out_format)
}
}

fn run<T: Copy + Datum + num_traits::Zero, W: PackingWriter<T>>(
&self,
geo: &ConcreteGeometry,
input: &TensorView,
g: usize,
pad_value: Option<&Tensor>,
writer: &mut W,
) -> TractResult<()> {
match self {
Patcher::Valid1d => Self::valid_1d::<T>(geo, input, pack, g),
Patcher::Valid2d => Self::valid_2d::<T>(geo, input, pack, g),
Patcher::Padded2d => Self::padded_2d::<T>(
Patcher::Valid1d => Self::valid_1d::<T, W>(geo, input, g, writer),
Patcher::Valid2d => Self::valid_2d::<T, W>(geo, input, g, writer),
Patcher::Padded2d => Self::padded_2d::<T, W>(
geo,
input,
pack,
g,
pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
writer,
),
_ => Self::generic::<T>(
_ => Self::generic::<T, W>(
geo,
input,
pack,
g,
pad_value.unwrap_or(&Tensor::zero_scalar::<T>()?),
writer,
),
}
}

#[inline(never)]
fn generic<'p, T: Copy + Datum>(
geometry: &'p ConcreteGeometry,
fn generic<T: Copy + Datum, W: PackingWriter<T>>(
geometry: &ConcreteGeometry,
input: &TensorView,
pack: &'p mut TensorView,
g: usize,
pad_value: &Tensor,
writer: &mut W,
) -> TractResult<()> {
unsafe {
let pad_value = *pad_value.to_scalar_unchecked();
Expand All @@ -307,25 +343,28 @@ impl Patcher {
}
}
}
geometry.b_pack.pack(pack, mega_matrix.view(), 0, 1);
// mega_matrix is [k, n] (k-major). Feed it to the writer one K row at a time
// (K-outer); the writer lays out the kernel's packing: K-major for PackedFormat
// (byte-identical to PackedFormat::pack), K=4-inner for PackedI8K4.
let mv = mega_matrix.as_slice_unchecked::<T>();
for kk in 0..geometry.k {
writer.write_slice(&mv[kk * geometry.n..(kk + 1) * geometry.n]);
}
Ok(())
}
}

#[inline(never)]
fn valid_1d<'p, T: Copy + Datum>(
geometry: &'p ConcreteGeometry,
fn valid_1d<T: Copy + Datum, W: PackingWriter<T>>(
geometry: &ConcreteGeometry,
input: &TensorView,
pack: &'p mut TensorView,
g: usize,
writer: &mut W,
) -> TractResult<()> {
unsafe {
let x_stride = *geometry.input_shape_with_n.h_stride() as isize
* geometry.pool.patch.spec.strides[0] as isize;
let c_stride = *geometry.input_shape_with_n.c_stride() as isize;
let pack = pack.as_slice_mut_unchecked::<T>();
let mut writer =
geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
let iptr = input.as_ptr_unchecked::<T>();
let iptr = iptr.add(g * geometry.ci_per_group * geometry.input_shape_with_n.c_stride());
let output_x = *geometry.pool.patch.output_shape.get_unchecked(0);
Expand Down Expand Up @@ -356,16 +395,15 @@ impl Patcher {
}

#[inline(never)]
fn padded_2d<'p, T: Copy + Datum>(
geometry: &'p ConcreteGeometry,
fn padded_2d<T: Copy + Datum, W: PackingWriter<T>>(
geometry: &ConcreteGeometry,
input: &TensorView,
pack: &'p mut TensorView,
g: usize,
pad_value: &Tensor,
writer: &mut W,
) -> TractResult<()> {
unsafe {
let pad_value = *pad_value.to_scalar_unchecked();
let pack = pack.as_slice_mut_unchecked::<T>();
let y_stride = geometry.pool.patch.spec.strides[0] as isize;
let x_stride = geometry.pool.patch.spec.strides[1] as isize;
let shape = &geometry.input_shape_with_n;
Expand All @@ -375,8 +413,6 @@ impl Patcher {
let input_heigth = shape.hw_dims()[0] as isize;
let input_width = shape.hw_dims()[1] as isize;
let kernel_len = geometry.pool.patch.standard_layout_data_field.len();
let mut writer =
geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
let iptr = input.as_ptr_unchecked::<T>();
let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
let output_width = *geometry.pool.patch.output_shape.get_unchecked(1);
Expand All @@ -402,22 +438,22 @@ impl Patcher {
Self::padded_2d_invalid_x_loop(
valid_x_start as usize,
pad_value,
&mut writer,
&mut *writer,
);
Self::padded_2d_valid_x_loop(
valid_x_start,
valid_x_end,
x_stride_ptr,
iptr,
&mut writer,
&mut *writer,
);
Self::padded_2d_invalid_x_loop(
output_width - valid_x_end as usize,
pad_value,
&mut writer,
&mut *writer,
);
} else {
Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut writer);
Self::padded_2d_invalid_x_loop(output_width, pad_value, &mut *writer);
}
}
}
Expand All @@ -427,23 +463,23 @@ impl Patcher {
}

#[inline(never)]
unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum>(
unsafe fn padded_2d_invalid_x_loop<T: Copy + Datum, W: PackingWriter<T>>(
count: usize,
pad_value: T,
writer: &mut tract_linalg::pack::KOutWriter<T>,
writer: &mut W,
) {
for _ in 0..count {
writer.write(pad_value);
}
}

#[inline(never)]
unsafe fn padded_2d_valid_x_loop<T: Copy + Datum>(
unsafe fn padded_2d_valid_x_loop<T: Copy + Datum, W: PackingWriter<T>>(
x_min: isize,
x_max: isize,
x_stride_ptr: isize,
iptr: *const T,
writer: &mut tract_linalg::pack::KOutWriter<T>,
writer: &mut W,
) {
// Fast path: x_stride_ptr == 1 means consecutive x values are at
// consecutive memory addresses, so the inner loop is a contiguous
Expand All @@ -461,22 +497,19 @@ impl Patcher {
}

#[inline(never)]
fn valid_2d<'p, T: Copy + Datum>(
geometry: &'p ConcreteGeometry,
fn valid_2d<T: Copy + Datum, W: PackingWriter<T>>(
geometry: &ConcreteGeometry,
input: &TensorView,
pack: &'p mut TensorView,
g: usize,
writer: &mut W,
) -> TractResult<()> {
unsafe {
let pack = pack.as_slice_mut_unchecked::<T>();
let shape = &geometry.input_shape_with_n;
let y_stride = geometry.pool.patch.spec.strides[0] as isize;
let x_stride = geometry.pool.patch.spec.strides[1] as isize;
let y_stride_ptr = y_stride * *shape.h_stride() as isize;
let x_stride_ptr = x_stride * *shape.w_stride() as isize;
let c_stride_ptr = *shape.c_stride() as isize;
let mut writer =
geometry.b_pack.write_with_k_outer(pack.as_mut_ptr(), geometry.k, geometry.n);
let iptr = input.as_ptr_unchecked::<T>();
let iptr = iptr.add(g * geometry.ci_per_group * shape.c_stride());
let output_y = *geometry.pool.patch.output_shape.get_unchecked(0);
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/lazy_im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl TypedOp for LazyIm2Col {
let exotic_fact = DynPackedExoticFact {
k: self.params.k_byte_offsets.len().to_dim(),
mn: self.params.n_byte_offsets.len().to_dim(),
packers: vec![self.params.packer.clone()],
packers: vec![Box::new(self.params.packer.clone()) as Box<dyn MMMInputFormat>],
};
Ok(tvec!(inputs[0].datum_type.fact([1, self.group]).with_exotic_fact(exotic_fact)))
}
Expand Down
Loading