Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/java/io/substrait/expression/EnumArg.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,9 @@ static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) {
return ImmutableEnumArg.builder().value(Optional.of(option)).build();
}

static EnumArg of(String value) {
return ImmutableEnumArg.builder().value(Optional.of(value)).build();
}

EnumArg UNSPECIFIED_ENUM_ARG = ImmutableEnumArg.builder().value(Optional.empty()).build();
}
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Substrait Java is a project that makes it easier to build [Substrait](https://substrait.io/) plans through Java. The project has two main parts:
1) **Core** is the module that supports building Substrait plans directly through Java. This is much easier than manipulating the Substrait protobuf directly. It has no direct support for going from SQL to Substrait (that's covered by the second part)
2) **Isthmus** is the module that allows going from SQL to a Substrait plan. Both Java APIs and a top level script for conversion are present. Not all SQL is supported yet by this module, but a lot is. For example, all of the TPC-H queries and all but a few of the TPC-DS queries are translatable.
3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. The most commonly used logical relations are supported, including those generated from all of the TPC-H queries, but there are currently some gaps in support that prevent all of the TPC-DS queries from being translatable.
3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. The most commonly used logical relations and functions are supported, including those generated from all of the TPC-H and TCP-DS queries.

## Building
After you've cloned the project through git, Substrait Java is built with a tool called [Gradle](https://gradle.org/). To build, execute the following:
Expand Down
17 changes: 10 additions & 7 deletions spark/src/main/resources/spark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
%YAML 1.2
---
scalar_functions:
-
name: year
description: Returns the year component of the date/timestamp
impls:
- args:
- value: date
return: i32
-
name: unscaled
description: >-
Expand All @@ -41,6 +34,16 @@ scalar_functions:
- args:
- value: i64
return: DECIMAL<P,S>
- name: add
description: >-
Adds days to a date
impls:
- args:
- name: start_date
value: date
- name: days
value: i32
return: date
- name: shift_right
description: >-
Bitwise (signed) shift right.
Expand Down
39 changes: 39 additions & 0 deletions spark/src/main/scala/io/substrait/spark/expression/Enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.substrait.spark.expression

import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Unevaluable}
import org.apache.spark.sql.types.{DataType, NullType}

/**
* For internal use only. This represents the equivalent of a Substrait enum parameter type for use
* during conversion. It must not become part of a final Spark logical plan.
*
* @param value
* The enum string value.
*/
case class Enum(value: String) extends LeafExpression with Unevaluable {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, mind adding a docstring to explain what this is for though?

override def nullable: Boolean = false

override def dataType: DataType = NullType

override def equals(that: Any): Boolean = that match {
case Enum(other) => other == value
case _ => false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType

import com.google.common.collect.{ArrayListMultimap, Multimap}
import io.substrait.`type`.Type
import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg}
import io.substrait.expression.{EnumArg, Expression => SExpression, ExpressionCreator, FunctionArg}
import io.substrait.expression.Expression.FailureBehavior
import io.substrait.extension.SimpleExtension
import io.substrait.function.{ParameterizedType, ToTypeString}
Expand Down Expand Up @@ -93,14 +93,28 @@ abstract class FunctionConverter[F <: SimpleExtension.Function, T](functions: Se
(matcherMap, keyMap)
}

def getSparkExpressionFromSubstraitFunc(key: String, outputType: Type): Option[Sig] = {
val sigs = substraitFuncKeyToSig.get(key)
sigs.size() match {
case 0 => None
case 1 => Some(sigs.iterator().next())
case _ => None
def getSparkExpressionFromSubstraitFunc(
key: String,
args: Seq[Expression]): Option[Expression] = {
val candidates = substraitFuncKeyToSig.get(key).asScala.toList
val sigs = if (candidates.length > 1) {
// attempt to disambiguate with the key (if it's been set)
val specific = candidates.filter {
case SpecialSig(_, _, Some(sig), _) if sig == key => true
case _ => false
}
if (specific.nonEmpty) {
specific
} else {
// no matching signature, so select the generic one(s)
candidates
}
} else {
candidates
}
sigs.headOption.map(sig => sig.makeCall(args))
}

private def createFinder(name: String, functions: Seq[F]): FunctionFinder[F, T] = {
new FunctionFinder[F, T](
name,
Expand Down Expand Up @@ -237,10 +251,14 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
val singularInputType: Option[SingularArgumentMatcher[F]],
val parent: FunctionConverter[F, T]) {

def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = {
val opTypes = operands.map(_.getType)
def attemptMatch(expression: Expression, operands: Seq[FunctionArg]): Option[T] = {
val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable)
val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE))

val opTypesStr = operands.map {
case e: SExpression => e.getType.accept(ToTypeString.INSTANCE)
case t: Type => t.accept(ToTypeString.INSTANCE)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you know if we have any tests for this (t: Type) case?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this gets driven by the last test in DateTimeSuite.scala and by three of the TCP-H tests (that extract the year component - this was previously handled by an internal definition in spark.yaml)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't the use the EnumArg case, not the Type case?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry. The Type case is there for completeness. As far as I can see, there are no currently supported functions that use type arguments. I suppose we could throw an unsupported exception in this case, if you'd prefer.

case _: EnumArg => "req"
}

val possibleKeys =
Util.crossProduct(opTypesStr.map(s => Seq(s))).map(list => list.mkString("_"))
Expand All @@ -251,11 +269,11 @@ class FunctionFinder[F <: SimpleExtension.Function, T](

if (operands.isEmpty) {
val variant = directMap(name + ":")
variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
// TODO validate the output type
Option(parent.generateBinding(expression, variant, operands, outputType))
} else if (directMatchKey.isDefined) {
val variant = directMap(directMatchKey.get)
variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
// TODO validate the output type
val funcArgs: Seq[FunctionArg] = operands
Option(parent.generateBinding(expression, variant, funcArgs, outputType))
} else if (singularInputType.isDefined) {
Expand All @@ -277,9 +295,7 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
.map(
declaration => {
val coercedArgs = coerceArguments(operands, leastRestrictiveSubstraitT)
declaration.validateOutputType(
JavaConverters.bufferAsJavaList(coercedArgs.toBuffer),
outputType)
// TODO validate the output type
val funcArgs: Seq[FunctionArg] = coercedArgs
parent.generateBinding(expression, declaration, funcArgs, outputType)
})
Expand All @@ -293,14 +309,15 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
* not for nullability or parameter mismatches.
*/
private def coerceArguments(arguments: Seq[SExpression], t: Type): Seq[SExpression] = {
arguments.map(
a => {
private def coerceArguments(arguments: Seq[FunctionArg], t: Type): Seq[FunctionArg] = {
arguments.map {
case a: SExpression =>
if (FunctionFinder.isMatch(t, a.getType)) {
a
} else {
ExpressionCreator.cast(t, a, FailureBehavior.THROW_EXCEPTION)
}
})
case other => other
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,75 @@ package io.substrait.spark.expression
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType}

import io.substrait.utils.Util

import scala.reflect.ClassTag

case class Sig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression) {
def makeCall(args: Seq[Expression]): Expression =
trait Sig {
def name: String
def expClass: Class[_]
def makeCall(args: Seq[Expression]): Expression
}

case class GenericSig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression)
extends Sig {
override def makeCall(args: Seq[Expression]): Expression = {
builder(args)
}
}

case class SpecialSig(
expClass: Class[_],
name: String,
key: Option[String],
builder: Seq[Expression] => Expression)
extends Sig {
override def makeCall(args: Seq[Expression]): Expression = {
builder(args)
}
}

object DateFunction {
def unapply(e: Expression): Option[Seq[Expression]] = e match {
case DateAdd(startDate, days) => Some(Seq(startDate, days))
// The following map to the Substrait `extract` function.
case Year(date) => Some(Seq(Enum("YEAR"), date))
case Quarter(date) => Some(Seq(Enum("QUARTER"), Enum("ONE"), date))
case Month(date) => Some(Seq(Enum("MONTH"), Enum("ONE"), date))
case DayOfMonth(date) => Some(Seq(Enum("DAY"), Enum("ONE"), date))
case _ => None
}

def unapply(name_args: (String, Seq[Expression])): Option[Expression] = name_args match {
case ("add:date_i32", Seq(startDate, days)) => Some(DateAdd(startDate, days))
case ("extract", Seq(Enum("YEAR"), date)) => Some(Year(date))
case ("extract", Seq(Enum("QUARTER"), Enum("ONE"), date)) => Some(Quarter(date))
case ("extract", Seq(Enum("MONTH"), Enum("ONE"), date)) => Some(Month(date))
case ("extract", Seq(Enum("DAY"), Enum("ONE"), date)) => Some(DayOfMonth(date))
case _ => None
}
}

class FunctionMappings {

private def s[T <: Expression: ClassTag](name: String): Sig = {
private def s[T <: Expression: ClassTag](name: String): GenericSig = {
val builder = FunctionRegistryBase.build[T](name, None)._2
Sig(scala.reflect.classTag[T].runtimeClass, name, builder)
GenericSig(scala.reflect.classTag[T].runtimeClass, name, builder)
}

private def ss[T <: Expression: ClassTag](signature: String): SpecialSig = {
val (name, key) = if (signature.contains(":")) {
(signature.split(':').head, Some(signature))
} else {
(signature, None)
}
val builder = (args: Seq[Expression]) =>
(signature, args) match {
case DateFunction(expr) => expr
}
SpecialSig(scala.reflect.classTag[T].runtimeClass, name, key, builder)
}

val SCALAR_SIGS: Seq[Sig] = Seq(
Expand Down Expand Up @@ -82,12 +138,18 @@ class FunctionMappings {
s[Lower]("lower"),
s[Concat]("concat"),
s[Coalesce]("coalesce"),
s[Year]("year"),
s[ShiftRight]("shift_right"),
s[BitwiseAnd]("bitwise_and"),
s[BitwiseOr]("bitwise_or"),
s[BitwiseXor]("bitwise_xor"),

// date/time functions require special handling
ss[DateAdd]("add:date_i32"),
ss[Year]("extract"),
ss[Quarter]("extract"),
ss[Month]("extract"),
ss[DayOfMonth]("extract"),

// internal
s[MakeDecimal]("make_decimal"),
s[UnscaledValue]("unscaled")
Expand Down Expand Up @@ -115,11 +177,6 @@ class FunctionMappings {
s[Lag]("lag"),
s[NthValue]("nth_value")
)

lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap
lazy val aggregate_functions_map: Map[Class[_], Sig] =
AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap
lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap
}

object FunctionMappings extends FunctionMappings
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar
.build()
}

def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = {
def convert(expression: Expression, operands: Seq[FunctionArg]): Option[SExpression] = {
Option(signatures.get(expression.getClass))
.flatMap(m => m.attemptMatch(expression, operands))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ import io.substrait.spark.logical.ToLogicalPlan
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, Decimal}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.{EnumArg, Expression => SExpression}
import io.substrait.extension.SimpleExtension
import io.substrait.util.DecimalUtil
import io.substrait.utils.Util

Expand Down Expand Up @@ -162,10 +163,11 @@ class ToSparkExpression(
override def visit(expr: SExpression.Cast): Expression = {
val childExp = expr.input().accept(this)
val tt = ToSparkType.convert(expr.getType)
val tz = childExp.dataType match {
case DateType => Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE))
case _ => None
}
val tz =
if (Cast.needsTimeZone(childExp.dataType, tt))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE))
else
None
Cast(childExp, tt, tz)
}

Expand Down Expand Up @@ -219,12 +221,19 @@ class ToSparkExpression(
}
}

override def visitEnumArg(
fnDef: SimpleExtension.Function,
argIdx: Int,
e: EnumArg): Expression = {
Enum(e.value.orElse(""))
}

override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
val eArgs = expr.arguments().asScala
val args = eArgs.zipWithIndex.map {
case (arg, i) =>
arg.accept(expr.declaration(), i, this)
}
}.toList

expr.declaration.name match {
case "make_decimal" if expr.declaration.uri == SparkExtension.uri =>
Expand All @@ -238,8 +247,7 @@ class ToSparkExpression(
}
case _ =>
scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
.flatMap(sig => Option(sig.makeCall(args)))
.getSparkExpressionFromSubstraitFunc(expr.declaration.key, args)
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
Expand Down
Loading
Loading