diff --git a/vortex-array/src/expr/exprs/list_contains.rs b/vortex-array/src/expr/exprs/list_contains.rs
index 68d7ed31bad..50ff7356b7e 100644
--- a/vortex-array/src/expr/exprs/list_contains.rs
+++ b/vortex-array/src/expr/exprs/list_contains.rs
@@ -13,17 +13,15 @@ use vortex_dtype::IntegerPType;
 use vortex_dtype::Nullability;
 use vortex_dtype::PTypeDowncastExt;
 use vortex_dtype::match_each_integer_ptype;
-use vortex_error::VortexExpect;
 use vortex_error::VortexResult;
 use vortex_error::vortex_bail;
 use vortex_error::vortex_err;
 use vortex_mask::Mask;
 use vortex_vector::BoolDatum;
 use vortex_vector::Datum;
-use vortex_vector::ScalarOps;
 use vortex_vector::Vector;
-use vortex_vector::VectorMutOps;
 use vortex_vector::VectorOps;
+use vortex_vector::bool::BoolScalar;
 use vortex_vector::bool::BoolVector;
 use vortex_vector::listview::ListViewScalar;
 use vortex_vector::listview::ListViewVector;
@@ -128,30 +126,28 @@ impl VTable for ListContains {
             .try_into()
             .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
 
-        let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
-            (true, true) => {
-                let list = lhs.into_scalar().vortex_expect("scalar").into_list();
-                let needle = rhs.into_scalar().vortex_expect("scalar");
-                // Convert the needle scalar to a vector with row_count
-                // elements and reuse constant_list_scalar_contains
-                let needle_vector = needle.repeat(args.row_count).freeze();
-                constant_list_scalar_contains(list, needle_vector)?
+        match (lhs, rhs) {
+            (Datum::Scalar(list_scalar), Datum::Scalar(needle_scalar)) => {
+                let list = list_scalar.into_list();
+                let found = list_contains_scalar_scalar(&list, &needle_scalar)?;
+                Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
             }
-            (true, false) => constant_list_scalar_contains(
-                lhs.into_scalar().vortex_expect("scalar").into_list(),
-                rhs.into_vector().vortex_expect("vector"),
-            )?,
-            (false, true) => list_contains_scalar(
-                lhs.unwrap_into_vector(args.row_count).into_list(),
-                rhs.into_scalar().vortex_expect("scalar").into_list(),
-            )?,
-            (false, false) => {
+            (Datum::Scalar(list_scalar), Datum::Vector(needle_vector)) => {
+                let matches =
+                    constant_list_scalar_contains(list_scalar.into_list(), needle_vector)?;
+                Ok(Datum::Vector(matches.into()))
+            }
+            (Datum::Vector(list_vector), Datum::Scalar(needle_scalar)) => {
+                let matches =
+                    list_contains_scalar(list_vector.into_list(), needle_scalar.into_list())?;
+                Ok(Datum::Vector(matches.into()))
+            }
+            (Datum::Vector(_), Datum::Vector(_)) => {
                 vortex_bail!(
                     "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
                 )
             }
-        };
-        Ok(Datum::Vector(matches.into()))
+        }
     }
 
     fn stat_falsification(
@@ -330,6 +326,45 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
     Ok(result)
 }
 
+/// Checks whether a single list scalar contains a scalar needle.
+///
+/// Every element of `list` is compared against `needle` with the `Binary` expression bound to
+/// `Operator::Eq`; the result is `true` when at least one bit of the comparison output is set.
+///
+/// # Errors
+///
+/// Propagates any error returned by executing the equality comparison.
+fn list_contains_scalar_scalar(
+    list: &ListViewScalar,
+    needle: &vortex_vector::Scalar,
+) -> VortexResult<bool> {
+    let elements = list.value().elements();
+
+    // Note: If the comparison becomes a bottleneck, look into faster ways to check for list
+    // containment. `execute` allocates the returned vector on the heap. Further, the `eq`
+    // comparison does not short-circuit on the first match found.
+    let found = Binary
+        .bind(operators::Operator::Eq)
+        .execute(ExecutionArgs {
+            datums: vec![
+                Datum::Vector(elements.deref().clone()),
+                Datum::Scalar(needle.clone()),
+            ],
+            dtypes: vec![],
+            row_count: elements.len(),
+            return_dtype: DType::Bool(Nullability::Nullable),
+        })?
+        .unwrap_into_vector(elements.len())
+        .into_bool()
+        .into_bits();
+
+    // NOTE(review): only the value bits are scanned, starting at bit offset 0. Confirm that
+    // `found` never carries a non-zero offset and that null comparison results leave their
+    // bit unset; otherwise this can report a false match.
+    let mut true_bits = BitIndexIterator::new(found.inner().as_ref(), 0, found.len());
+    Ok(true_bits.next().is_some())
+}
+
 /// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
 /// [`BoolArray`] of matches on the child elements array.
 ///