diff --git a/core/src/main/java/io/substrait/expression/EnumArg.java b/core/src/main/java/io/substrait/expression/EnumArg.java index 041b62104..d77be204f 100644 --- a/core/src/main/java/io/substrait/expression/EnumArg.java +++ b/core/src/main/java/io/substrait/expression/EnumArg.java @@ -26,5 +26,9 @@ static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) { return ImmutableEnumArg.builder().value(Optional.of(option)).build(); } + static EnumArg of(String value) { + return ImmutableEnumArg.builder().value(Optional.of(value)).build(); + } + EnumArg UNSPECIFIED_ENUM_ARG = ImmutableEnumArg.builder().value(Optional.empty()).build(); } diff --git a/readme.md b/readme.md index d527584a9..3a18ac0d2 100644 --- a/readme.md +++ b/readme.md @@ -3,7 +3,7 @@ Substrait Java is a project that makes it easier to build [Substrait](https://substrait.io/) plans through Java. The project has two main parts: 1) **Core** is the module that supports building Substrait plans directly through Java. This is much easier than manipulating the Substrait protobuf directly. It has no direct support for going from SQL to Substrait (that's covered by the second part) 2) **Isthmus** is the module that allows going from SQL to a Substrait plan. Both Java APIs and a top level script for conversion are present. Not all SQL is supported yet by this module, but a lot is. For example, all of the TPC-H queries and all but a few of the TPC-DS queries are translatable. -3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. The most commonly used logical relations are supported, including those generated from all of the TPC-H queries, but there are currently some gaps in support that prevent all of the TPC-DS queries from being translatable. +3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. 
The most commonly used logical relations and functions are supported, including those generated from all of the TPC-H and TPC-DS queries. ## Building After you've cloned the project through git, Substrait Java is built with a tool called [Gradle](https://gradle.org/). To build, execute the following: diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml index 9b509a577..93435680b 100644 --- a/spark/src/main/resources/spark.yml +++ b/spark/src/main/resources/spark.yml @@ -15,13 +15,6 @@ %YAML 1.2 --- scalar_functions: - - - name: year - description: Returns the year component of the date/timestamp - impls: - - args: - - value: date - return: i32 - name: unscaled description: >- @@ -41,6 +34,16 @@ scalar_functions: - args: - value: i64 return: DECIMAL + - name: add + description: >- + Adds days to a date + impls: + - args: + - name: start_date + value: date + - name: days + value: i32 + return: date - name: shift_right description: >- Bitwise (signed) shift right. diff --git a/spark/src/main/scala/io/substrait/spark/expression/Enum.scala b/spark/src/main/scala/io/substrait/spark/expression/Enum.scala new file mode 100644 index 000000000..eaf7d4411 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/Enum.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Unevaluable} +import org.apache.spark.sql.types.{DataType, NullType} + +/** + * For internal use only. This represents the equivalent of a Substrait enum parameter type for use + * during conversion. It must not become part of a final Spark logical plan. + * + * @param value + * The enum string value. + */ +case class Enum(value: String) extends LeafExpression with Unevaluable { + override def nullable: Boolean = false + + override def dataType: DataType = NullType + + override def equals(that: Any): Boolean = that match { + case Enum(other) => other == value + case _ => false + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala index 5c0f72692..1432396bb 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType import com.google.common.collect.{ArrayListMultimap, Multimap} import io.substrait.`type`.Type -import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg} +import io.substrait.expression.{EnumArg, Expression => SExpression, ExpressionCreator, FunctionArg} import io.substrait.expression.Expression.FailureBehavior import io.substrait.extension.SimpleExtension import io.substrait.function.{ParameterizedType, ToTypeString} @@ -93,14 +93,28 @@ abstract class FunctionConverter[F <: SimpleExtension.Function, T](functions: Se (matcherMap, keyMap) } - def getSparkExpressionFromSubstraitFunc(key: String, outputType: Type): Option[Sig] = { - val sigs = substraitFuncKeyToSig.get(key) - sigs.size() match { - case 0 => None - case 1 => 
Some(sigs.iterator().next()) - case _ => None + def getSparkExpressionFromSubstraitFunc( + key: String, + args: Seq[Expression]): Option[Expression] = { + val candidates = substraitFuncKeyToSig.get(key).asScala.toList + val sigs = if (candidates.length > 1) { + // attempt to disambiguate with the key (if it's been set) + val specific = candidates.filter { + case SpecialSig(_, _, Some(sig), _) if sig == key => true + case _ => false + } + if (specific.nonEmpty) { + specific + } else { + // no matching signature, so select the generic one(s) + candidates + } + } else { + candidates } + sigs.headOption.map(sig => sig.makeCall(args)) } + private def createFinder(name: String, functions: Seq[F]): FunctionFinder[F, T] = { new FunctionFinder[F, T]( name, @@ -237,10 +251,14 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( val singularInputType: Option[SingularArgumentMatcher[F]], val parent: FunctionConverter[F, T]) { - def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = { - val opTypes = operands.map(_.getType) + def attemptMatch(expression: Expression, operands: Seq[FunctionArg]): Option[T] = { val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable) - val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE)) + + val opTypesStr = operands.map { + case e: SExpression => e.getType.accept(ToTypeString.INSTANCE) + case t: Type => t.accept(ToTypeString.INSTANCE) + case _: EnumArg => "req" + } val possibleKeys = Util.crossProduct(opTypesStr.map(s => Seq(s))).map(list => list.mkString("_")) @@ -251,11 +269,11 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( if (operands.isEmpty) { val variant = directMap(name + ":") - variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) + // TODO validate the output type Option(parent.generateBinding(expression, variant, operands, outputType)) } else if (directMatchKey.isDefined) { val variant = directMap(directMatchKey.get) 
- variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) + // TODO validate the output type val funcArgs: Seq[FunctionArg] = operands Option(parent.generateBinding(expression, variant, funcArgs, outputType)) } else if (singularInputType.isDefined) { @@ -277,9 +295,7 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( .map( declaration => { val coercedArgs = coerceArguments(operands, leastRestrictiveSubstraitT) - declaration.validateOutputType( - JavaConverters.bufferAsJavaList(coercedArgs.toBuffer), - outputType) + // TODO validate the output type val funcArgs: Seq[FunctionArg] = coercedArgs parent.generateBinding(expression, declaration, funcArgs, outputType) }) @@ -293,14 +309,15 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( * Coerced types according to an expected output type. Coercion is only done for type mismatches, * not for nullability or parameter mismatches. */ - private def coerceArguments(arguments: Seq[SExpression], t: Type): Seq[SExpression] = { - arguments.map( - a => { + private def coerceArguments(arguments: Seq[FunctionArg], t: Type): Seq[FunctionArg] = { + arguments.map { + case a: SExpression => if (FunctionFinder.isMatch(t, a.getType)) { a } else { ExpressionCreator.cast(t, a, FailureBehavior.THROW_EXCEPTION) } - }) + case other => other + } } } diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index 157fe412e..d3efae4ad 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -19,19 +19,75 @@ package io.substrait.spark.expression import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.{DayTimeIntervalType, 
IntegerType} + +import io.substrait.utils.Util import scala.reflect.ClassTag -case class Sig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression) { - def makeCall(args: Seq[Expression]): Expression = +trait Sig { + def name: String + def expClass: Class[_] + def makeCall(args: Seq[Expression]): Expression +} + +case class GenericSig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression) + extends Sig { + override def makeCall(args: Seq[Expression]): Expression = { builder(args) + } +} + +case class SpecialSig( + expClass: Class[_], + name: String, + key: Option[String], + builder: Seq[Expression] => Expression) + extends Sig { + override def makeCall(args: Seq[Expression]): Expression = { + builder(args) + } +} + +object DateFunction { + def unapply(e: Expression): Option[Seq[Expression]] = e match { + case DateAdd(startDate, days) => Some(Seq(startDate, days)) + // The following map to the Substrait `extract` function. + case Year(date) => Some(Seq(Enum("YEAR"), date)) + case Quarter(date) => Some(Seq(Enum("QUARTER"), Enum("ONE"), date)) + case Month(date) => Some(Seq(Enum("MONTH"), Enum("ONE"), date)) + case DayOfMonth(date) => Some(Seq(Enum("DAY"), Enum("ONE"), date)) + case _ => None + } + + def unapply(name_args: (String, Seq[Expression])): Option[Expression] = name_args match { + case ("add:date_i32", Seq(startDate, days)) => Some(DateAdd(startDate, days)) + case ("extract", Seq(Enum("YEAR"), date)) => Some(Year(date)) + case ("extract", Seq(Enum("QUARTER"), Enum("ONE"), date)) => Some(Quarter(date)) + case ("extract", Seq(Enum("MONTH"), Enum("ONE"), date)) => Some(Month(date)) + case ("extract", Seq(Enum("DAY"), Enum("ONE"), date)) => Some(DayOfMonth(date)) + case _ => None + } } class FunctionMappings { - private def s[T <: Expression: ClassTag](name: String): Sig = { + private def s[T <: Expression: ClassTag](name: String): GenericSig = { val builder = FunctionRegistryBase.build[T](name, None)._2 - 
Sig(scala.reflect.classTag[T].runtimeClass, name, builder) + GenericSig(scala.reflect.classTag[T].runtimeClass, name, builder) + } + + private def ss[T <: Expression: ClassTag](signature: String): SpecialSig = { + val (name, key) = if (signature.contains(":")) { + (signature.split(':').head, Some(signature)) + } else { + (signature, None) + } + val builder = (args: Seq[Expression]) => + (signature, args) match { + case DateFunction(expr) => expr + } + SpecialSig(scala.reflect.classTag[T].runtimeClass, name, key, builder) } val SCALAR_SIGS: Seq[Sig] = Seq( @@ -82,12 +138,18 @@ class FunctionMappings { s[Lower]("lower"), s[Concat]("concat"), s[Coalesce]("coalesce"), - s[Year]("year"), s[ShiftRight]("shift_right"), s[BitwiseAnd]("bitwise_and"), s[BitwiseOr]("bitwise_or"), s[BitwiseXor]("bitwise_xor"), + // date/time functions require special handling + ss[DateAdd]("add:date_i32"), + ss[Year]("extract"), + ss[Quarter]("extract"), + ss[Month]("extract"), + ss[DayOfMonth]("extract"), + // internal s[MakeDecimal]("make_decimal"), s[UnscaledValue]("unscaled") @@ -115,11 +177,6 @@ class FunctionMappings { s[Lag]("lag"), s[NthValue]("nth_value") ) - - lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap - lazy val aggregate_functions_map: Map[Class[_], Sig] = - AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap - lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap } object FunctionMappings extends FunctionMappings diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala index dac4873c3..d1ea06001 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala @@ -40,7 +40,7 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar .build() } - def 
convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = { + def convert(expression: Expression, operands: Seq[FunctionArg]): Option[SExpression] = { Option(signatures.get(expression.getClass)) .flatMap(m => m.attemptMatch(expression, operands)) } diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index 9d3fea1ea..8c985b645 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -22,13 +22,14 @@ import io.substrait.spark.logical.ToLogicalPlan import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DateType, Decimal} +import org.apache.spark.sql.types.Decimal import org.apache.spark.substrait.SparkTypeUtil import org.apache.spark.unsafe.types.UTF8String import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} -import io.substrait.expression.{Expression => SExpression} +import io.substrait.expression.{EnumArg, Expression => SExpression} +import io.substrait.extension.SimpleExtension import io.substrait.util.DecimalUtil import io.substrait.utils.Util @@ -162,10 +163,11 @@ class ToSparkExpression( override def visit(expr: SExpression.Cast): Expression = { val childExp = expr.input().accept(this) val tt = ToSparkType.convert(expr.getType) - val tz = childExp.dataType match { - case DateType => Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)) - case _ => None - } + val tz = + if (Cast.needsTimeZone(childExp.dataType, tt)) + Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)) + else + None Cast(childExp, tt, tz) } @@ -219,12 +221,19 @@ class 
ToSparkExpression( } } + override def visitEnumArg( + fnDef: SimpleExtension.Function, + argIdx: Int, + e: EnumArg): Expression = { + Enum(e.value.orElse("")) + } + override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = { val eArgs = expr.arguments().asScala val args = eArgs.zipWithIndex.map { case (arg, i) => arg.accept(expr.declaration(), i, this) - } + }.toList expr.declaration.name match { case "make_decimal" if expr.declaration.uri == SparkExtension.uri => @@ -238,8 +247,7 @@ class ToSparkExpression( } case _ => scalarFunctionConverter - .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) - .flatMap(sig => Option(sig.makeCall(args))) + .getSparkExpressionFromSubstraitFunc(expr.declaration.key, args) .getOrElse({ val msg = String.format( "Unable to convert scalar function %s(%s).", diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala index b84ff46c6..bd9ac7fd8 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.substrait.SparkTypeUtil -import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FieldReference, ImmutableExpression} +import io.substrait.expression.{EnumArg, Expression => SExpression, ExpressionCreator, FieldReference, ImmutableEnumArg, ImmutableExpression} import io.substrait.expression.Expression.FailureBehavior import io.substrait.utils.Util +import java.util.Optional + import scala.collection.JavaConverters.asJavaIterableConverter /** The builder to generate substrait expressions from catalyst expressions. 
*/ @@ -33,6 +35,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { object ScalarFunction { def unapply(e: Expression): Option[Seq[Expression]] = e match { + case DateFunction(arguments) => Some(arguments) case BinaryExpression(left, right) => Some(Seq(left, right)) case UnaryExpression(child) => Some(Seq(child)) case t: TernaryExpression => Some(Seq(t.first, t.second, t.third)) @@ -196,7 +199,10 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v))) case scalar @ ScalarFunction(children) => Util - .seqToOption(children.map(translateUp)) + .seqToOption(children.map { + case Enum(value) => Some(EnumArg.of(value)) + case child: Expression => translateUp(child) + }) .flatMap(toScalarFunction.convert(scalar, _)) case p: PlanExpression[_] => translateSubQuery(p) case in: InSubquery => translateInSubquery(in) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index ab019754f..e64634f81 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -71,8 +71,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } val aggregateFunction = SparkExtension.toAggregateFunction - .getSparkExpressionFromSubstraitFunc(function.declaration.key, function.outputType) - .map(sig => sig.makeCall(arguments)) + .getSparkExpressionFromSubstraitFunc(function.declaration.key, arguments) .map(_.asInstanceOf[AggregateFunction]) .getOrElse({ val msg = String.format( @@ -137,8 +136,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] arg.accept(func.declaration(), i, expressionConverter) } val windowFunction = SparkExtension.toWindowFunction - .getSparkExpressionFromSubstraitFunc(func.declaration.key, 
func.outputType) - .map(sig => sig.makeCall(arguments)) + .getSparkExpressionFromSubstraitFunc(func.declaration.key, arguments) .map { case win: WindowFunction => win case agg: AggregateFunction => diff --git a/spark/src/test/scala/io/substrait/spark/DateTimeSuite.scala b/spark/src/test/scala/io/substrait/spark/DateTimeSuite.scala new file mode 100644 index 000000000..e3b911765 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/DateTimeSuite.scala @@ -0,0 +1,30 @@ +package io.substrait.spark + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSparkSession + +class DateTimeSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + } + + test("date_add") { + val qry = + "select cast(d AS DATE) + interval 5 days from (values ('2025-03-27'), ('2025-01-02')) as table(d)" + assertSqlSubstraitRelRoundTrip(qry) + } + + test("date_sub") { + val qry = + "select cast(d AS DATE) - interval 5 days from (values ('2025-03-27'), ('2025-01-02')) as table(d)" + assertSqlSubstraitRelRoundTrip(qry) + } + + test("extract_year_month") { + val qry = "select year(cast(d AS DATE)), month(cast(d AS DATE)) " + + "from (values ('2025-03-27'), ('2025-01-02')) as table(d)" + assertSqlSubstraitRelRoundTrip(qry) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 8676ac9d7..3613f48b5 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.internal.SQLConf class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { - private val runAllQueriesIncludeFailed = false override def beforeAll(): Unit = { super.beforeAll() sparkContext.setLogLevel("WARN") @@ -31,22 +30,10 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { 
spark.conf.set("spark.sql.readSideCharPadding", "false") } - // spotless:off - val failingSQL: Set[String] = Set( - "q72" //requires implementation of date_add() - ) - // spotless:on - tpcdsQueries.foreach { q => - if (runAllQueriesIncludeFailed || !failingSQL.contains(q)) { - test(s"check simplified (tpcds-v1.4/$q)") { - testQuery("tpcds", q) - } - } else { - ignore(s"check simplified (tpcds-v1.4/$q)") { - testQuery("tpcds", q) - } + test(s"check simplified (tpcds-v1.4/$q)") { + testQuery("tpcds", q) } } diff --git a/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala b/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala index e855e0d47..e3b20b1e5 100644 --- a/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala +++ b/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala @@ -22,11 +22,7 @@ import org.apache.spark.SparkFunSuite class YamlTest extends SparkFunSuite { - test("has_year_definition") { - assert( - SparkExtension.SparkScalarFunctions - .map(f => f.key()) - .exists(p => p.equals("year:date"))) + test("has_unscaled_definition") { assert( SparkExtension.SparkScalarFunctions .map(f => f.key())