Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import io.substrait.spark.logical.ToLogicalPlan
import io.substrait.spark.utils.Util

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Cast, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, Or, ScalarSubquery}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
Expand Down Expand Up @@ -271,22 +271,28 @@ class ToSparkExpression(
arg.accept(expr.declaration(), i, this, context)
}.toList

scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration.key, args)
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
expr.declaration.name,
expr.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
expr.declaration().name() match {
// Special handling for multi-variate AND/OR operators
case "and" if args.size > 2 => args.reduceLeft { (left, right) => And(left, right) }
case "or" if args.size > 2 => args.reduceLeft { (left, right) => Or(left, right) }
case _ =>
scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration.key, args)
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
expr.declaration.name,
expr.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,171 @@
*/
package io.substrait.spark.expression

import io.substrait.dsl.SubstraitBuilder
import io.substrait.expression.{Expression => SExpression, ExpressionCreator}
import io.substrait.extension.DefaultExtensionCatalog
import io.substrait.`type`.TypeCreator
import io.substrait.util.EmptyVisitationContext
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{And, Literal}
import org.apache.spark.sql.catalyst.expressions.{And, Literal, Or}
import org.scalatest.Assertions.assertResult

import scala.jdk.CollectionConverters._

class PredicateSuite extends SparkFunSuite with SubstraitExpressionTestBase {

private val extensions = DefaultExtensionCatalog.DEFAULT_COLLECTION
private val sb = new SubstraitBuilder(extensions)
private val toSparkExpression =
new ToSparkExpression(ToScalarFunction(io.substrait.spark.SparkExtension.SparkScalarFunctions))

test("And") {
runTest("and:bool", And(Literal(true), Literal(false)))
}

test("Multivariate AND with 3 arguments") {
// Create a substrait AND expression with 3 boolean arguments: a AND b AND c
val boolA = ExpressionCreator.bool(false, true)
val boolB = ExpressionCreator.bool(false, false)
val boolC = ExpressionCreator.bool(false, true)

val substraitAnd = sb.and(boolA, boolB, boolC)

// Verify it's an AND function
val scalarFunc = substraitAnd.asInstanceOf[SExpression.ScalarFunctionInvocation]
assertResult("and")(scalarFunc.declaration().name())
assertResult(3)(scalarFunc.arguments().size())

// Convert to Spark expression
val sparkExpr = substraitAnd.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)

// Verify the result is nested And expressions: And(And(a, b), c)
// reduceLeft creates: ((a AND b) AND c)
sparkExpr match {
case And(And(Literal(a, _), Literal(b, _)), Literal(c, _)) =>
assertResult(true)(a)
assertResult(false)(b)
assertResult(true)(c)
case _ =>
fail(s"Expected nested And expressions, got: $sparkExpr")
}
}

test("Multivariate AND with 4 arguments") {
// Create a substrait AND expression with 4 boolean arguments: a AND b AND c AND d
val boolA = ExpressionCreator.bool(false, true)
val boolB = ExpressionCreator.bool(false, true)
val boolC = ExpressionCreator.bool(false, false)
val boolD = ExpressionCreator.bool(false, true)

val substraitAnd = sb.and(boolA, boolB, boolC, boolD)

// Verify it's an AND function with 4 arguments
val scalarFunc = substraitAnd.asInstanceOf[SExpression.ScalarFunctionInvocation]
assertResult("and")(scalarFunc.declaration().name())
assertResult(4)(scalarFunc.arguments().size())

// Convert to Spark expression
val sparkExpr = substraitAnd.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)

// Verify the result is nested And expressions: And(And(And(a, b), c), d)
// reduceLeft creates: (((a AND b) AND c) AND d)
sparkExpr match {
case And(And(And(Literal(a, _), Literal(b, _)), Literal(c, _)), Literal(d, _)) =>
assertResult(true)(a)
assertResult(true)(b)
assertResult(false)(c)
assertResult(true)(d)
case _ =>
fail(s"Expected nested And expressions, got: $sparkExpr")
}
}

test("Multivariate OR with 3 arguments") {
// Create a substrait OR expression with 3 boolean arguments: a OR b OR c
val boolA = ExpressionCreator.bool(false, false)
val boolB = ExpressionCreator.bool(false, true)
val boolC = ExpressionCreator.bool(false, false)

val substraitOr = sb.or(boolA, boolB, boolC)

// Verify it's an OR function
val scalarFunc = substraitOr.asInstanceOf[SExpression.ScalarFunctionInvocation]
assertResult("or")(scalarFunc.declaration().name())
assertResult(3)(scalarFunc.arguments().size())

// Convert to Spark expression
val sparkExpr = substraitOr.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)

// Verify the result is nested Or expressions: Or(Or(a, b), c)
// reduceLeft creates: ((a OR b) OR c)
sparkExpr match {
case Or(Or(Literal(a, _), Literal(b, _)), Literal(c, _)) =>
assertResult(false)(a)
assertResult(true)(b)
assertResult(false)(c)
case _ =>
fail(s"Expected nested Or expressions, got: $sparkExpr")
}
}

test("Multivariate OR with 5 arguments") {
// Create a substrait OR expression with 5 boolean arguments: a OR b OR c OR d OR e
val boolA = ExpressionCreator.bool(false, false)
val boolB = ExpressionCreator.bool(false, false)
val boolC = ExpressionCreator.bool(false, true)
val boolD = ExpressionCreator.bool(false, false)
val boolE = ExpressionCreator.bool(false, false)

val substraitOr = sb.or(boolA, boolB, boolC, boolD, boolE)

// Verify it's an OR function with 5 arguments
val scalarFunc = substraitOr.asInstanceOf[SExpression.ScalarFunctionInvocation]
assertResult("or")(scalarFunc.declaration().name())
assertResult(5)(scalarFunc.arguments().size())

// Convert to Spark expression
val sparkExpr = substraitOr.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)

// Verify the result is nested Or expressions: Or(Or(Or(Or(a, b), c), d), e)
// reduceLeft creates: ((((a OR b) OR c) OR d) OR e)
sparkExpr match {
case Or(Or(Or(Or(Literal(a, _), Literal(b, _)), Literal(c, _)), Literal(d, _)), Literal(e, _)) =>
assertResult(false)(a)
assertResult(false)(b)
assertResult(true)(c)
assertResult(false)(d)
assertResult(false)(e)
case _ =>
fail(s"Expected nested Or expressions, got: $sparkExpr")
}
}

test("Mixed multivariate AND and OR") {
// Create a complex expression: (a AND b AND c) OR (d AND e)
val boolA = ExpressionCreator.bool(false, true)
val boolB = ExpressionCreator.bool(false, true)
val boolC = ExpressionCreator.bool(false, false)
val boolD = ExpressionCreator.bool(false, true)
val boolE = ExpressionCreator.bool(false, true)

val substraitAndLeft = sb.and(boolA, boolB, boolC)
val substraitAndRight = sb.and(boolD, boolE)
val substraitOr = sb.or(substraitAndLeft, substraitAndRight)

// Convert to Spark expression
val sparkExpr = substraitOr.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)

// Verify the structure: Or(And(And(a, b), c), And(d, e))
sparkExpr match {
case Or(And(And(Literal(a, _), Literal(b, _)), Literal(c, _)), And(Literal(d, _), Literal(e, _))) =>
assertResult(true)(a)
assertResult(true)(b)
assertResult(false)(c)
assertResult(true)(d)
assertResult(true)(e)
case _ =>
fail(s"Expected mixed nested And/Or expressions, got: $sparkExpr")
}
}
}
Loading