Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public TypeProtoConverter getTypeProtoConverter() {
}

public io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) {
return expression.accept(this, null);
return expression.accept(this, EmptyVisitationContext.INSTANCE);
}

public List<io.substrait.proto.Expression> toProto(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.proto.ExtendedExpression;
import io.substrait.relation.AggregateFunctionProtoConverter;
import io.substrait.type.proto.TypeProtoConverter;
import io.substrait.util.EmptyVisitationContext;

/**
* Converts from {@link io.substrait.extendedexpression.ExtendedExpression} to {@link
Expand All @@ -27,7 +28,7 @@ public ExtendedExpression toProto(
if (expressionReference
instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) {
io.substrait.proto.Expression expressionProto =
et.getExpression().accept(expressionProtoConverter, null);
et.getExpression().accept(expressionProtoConverter, EmptyVisitationContext.INSTANCE);
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionProto)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.AggregateFunction;
import io.substrait.type.proto.TypeProtoConverter;
import io.substrait.util.EmptyVisitationContext;
import java.util.stream.IntStream;

/**
Expand Down Expand Up @@ -34,7 +35,10 @@ public AggregateFunction toProto(Aggregate.Measure measure) {
.setOutputType(measure.getFunction().getType().accept(typeProtoConverter))
.addAllArguments(
IntStream.range(0, args.size())
.mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null))
.mapToObj(
i ->
args.get(i)
.accept(aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE))
.collect(java.util.stream.Collectors.toList()))
.setFunctionReference(
functionCollector.getFunctionReference(measure.getFunction().declaration()))
Expand Down
13 changes: 10 additions & 3 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) {
}

public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) {
return rel.accept(this, null);
return rel.accept(this, EmptyVisitationContext.INSTANCE);
}

protected io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) {
Expand Down Expand Up @@ -136,7 +136,10 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) {
.setOutputType(toProto(measure.getFunction().getType()))
.addAllArguments(
IntStream.range(0, args.size())
.mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null))
.mapToObj(
i ->
args.get(i)
.accept(aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE))
.collect(Collectors.toList()))
.addAllSorts(toProtoS(measure.getFunction().sort()))
.setFunctionReference(
Expand Down Expand Up @@ -463,7 +466,11 @@ private List<ConsistentPartitionWindowRel.WindowRelFunction> toProtoWindowRelFun

var arguments =
IntStream.range(0, args.size())
.mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null))
.mapToObj(
i ->
args.get(i)
.accept(
aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE))
.collect(Collectors.toList());
var options =
f.options().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.util.EmptyVisitationContext;
import java.util.Arrays;
import org.junit.jupiter.api.Test;

Expand All @@ -26,7 +27,7 @@ void ifThenNotNullable() {

var to = new ExpressionProtoConverter(null, null);
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
assertEquals(ifRel, from.from(ifRel.accept(to, null)));
assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE)));
}

@Test
Expand All @@ -40,6 +41,6 @@ void ifThenNullable() {

var to = new ExpressionProtoConverter(null, null);
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
assertEquals(ifRel, from.from(ifRel.accept(to, null)));
assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.util.EmptyVisitationContext;
import java.math.BigDecimal;
import org.junit.jupiter.api.Test;

Expand All @@ -17,6 +18,6 @@ void decimal() {
var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2);
var to = new ExpressionProtoConverter(null, null);
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
assertEquals(val, from.from(val.accept(to, null)));
assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private TableGatherer() {
*/
public static Map<List<String>, NamedStruct> gatherTables(Rel rootRel) {
var visitor = new TableGatherer();
rootRel.accept(visitor, null);
rootRel.accept(visitor, EmptyVisitationContext.INSTANCE);
return visitor.tableMap;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private NamedStructGatherer() {

public static Map<List<String>, NamedStruct> gatherTables(Rel rel) {
var visitor = new NamedStructGatherer();
rel.accept(visitor, null);
rel.accept(visitor, EmptyVisitationContext.INSTANCE);
return visitor.tableMap;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ public Optional<Rel> visit(Cross cross, EmptyVisitationContext context)
"orders" o
""",
new SqlToSubstrait(featureBoard));
plan1.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor, null));
plan1
.getRoots()
.forEach(
t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE));
assertEquals(1, counter[0]);

Plan plan2 =
Expand All @@ -101,7 +104,10 @@ public Optional<Rel> visit(Cross cross, EmptyVisitationContext context)
"orders" o
""",
new SqlToSubstrait(featureBoard));
plan2.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor, null));
plan2
.getRoots()
.forEach(
t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE));
assertEquals(2, counter[0]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ public void replaceCountDistinctsInUnion() throws IOException, SqlParseException
private static class HasTableReference {
public boolean hasTableReference(Plan plan, String name) {
HasTableReferenceVisitor visitor = new HasTableReferenceVisitor(Arrays.asList(name));
plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null));
plan.getRoots().stream()
.forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE));
return (visitor.hasTableReference());
}

Expand Down Expand Up @@ -190,7 +191,8 @@ private static class CountCountDistinct {

public int getCountDistincts(Plan plan) {
CountCountDistinctVisitor visitor = new CountCountDistinctVisitor();
plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null));
plan.getRoots().stream()
.forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE));
return visitor.getCountDistincts();
}

Expand Down Expand Up @@ -221,7 +223,8 @@ private static class CountApproxCountDistinct {

public int getApproxCountDistincts(Plan plan) {
CountCountDistinctVisitor visitor = new CountCountDistinctVisitor();
plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null));
plan.getRoots().stream()
.forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE));
return visitor.getApproxCountDistincts();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.Type;
import io.substrait.util.EmptyVisitationContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -62,7 +63,8 @@ void extensionMultiRelDetailTest() {
void roundtrip(Rel pojo1) {
// Substrait POJO 1 -> Substrait Proto
io.substrait.proto.Rel proto =
pojo1.accept(new RelProtoConverter(new ExtensionCollector()), null);
pojo1.accept(
new RelProtoConverter(new ExtensionCollector()), EmptyVisitationContext.INSTANCE);

// Substrait Proto -> Substrait POJO 2
var pojo2 = (new CustomProtoRelConverter(new ExtensionCollector())).from(proto);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
}

def apply(rel: Rel, maxFields: Int): String = {
rel.accept(this, null)
rel.accept(this, EmptyVisitationContext.INSTANCE)
}

override def visit(fetch: Fetch, context: EmptyVisitationContext): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
val function = measure.getFunction
var arguments = function.arguments().asScala.zipWithIndex.map {
case (arg, i) =>
arg.accept(function.declaration(), i, expressionConverter, null)
arg.accept(function.declaration(), i, expressionConverter, EmptyVisitationContext.INSTANCE)
}
if (function.declaration.name == "count" && function.arguments.size == 0) {
// HACK - count() needs to be rewritten as count(1)
Expand All @@ -92,7 +92,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
})

val filter = Option(measure.getPreMeasureFilter.orElse(null))
.map(_.accept(expressionConverter, null))
.map(_.accept(expressionConverter, EmptyVisitationContext.INSTANCE))

AggregateExpression(
aggregateFunction,
Expand Down Expand Up @@ -213,7 +213,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
}

private def toSortOrder(sortField: SExpression.SortField): SortOrder = {
val expression = sortField.expr().accept(expressionConverter, null)
val expression = sortField.expr().accept(expressionConverter, EmptyVisitationContext.INSTANCE)
val (direction, nullOrdering) = sortField.direction() match {
case SExpression.SortDirection.ASC_NULLS_FIRST => (Ascending, NullsFirst)
case SExpression.SortDirection.DESC_NULLS_FIRST => (Descending, NullsFirst)
Expand Down Expand Up @@ -449,7 +449,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
}

def convert(rel: relation.Rel): LogicalPlan = {
val logicalPlan = rel.accept(this, null)
val logicalPlan = rel.accept(this, EmptyVisitationContext.INSTANCE)
require(logicalPlan.resolved)
logicalPlan
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import io.substrait.relation.RelProtoConverter
import io.substrait.relation.Set.SetOp
import io.substrait.relation.files.{FileFormat, FileOrFiles}
import io.substrait.relation.files.FileOrFiles.PathType
import io.substrait.util.EmptyVisitationContext
import io.substrait.utils.Util

import java.util
Expand Down Expand Up @@ -575,7 +576,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
proto.PlanRel
.newBuilder()
.setRel(substraitRel
.accept(relProtoConverter, null))
.accept(relProtoConverter, EmptyVisitationContext.INSTANCE))
)
extensionCollector.addExtensionsToPlan(builder)
builder.build().toByteArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import io.substrait.extension.ExtensionCollector
import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter}
import io.substrait.proto
import io.substrait.relation.{ProtoRelConverter, RelProtoConverter}
import io.substrait.util.EmptyVisitationContext
import org.scalactic.Equality
import org.scalactic.source.Position
import org.scalatest.Succeeded
Expand Down Expand Up @@ -72,7 +73,7 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>

// convert substrait back to spark plan
val toLogicalPlan = new ToLogicalPlan(spark);
val sparkPlan2 = substraitRel2.accept(toLogicalPlan, null)
val sparkPlan2 = substraitRel2.accept(toLogicalPlan, EmptyVisitationContext.INSTANCE)
require(sparkPlan2.resolved)

// and back to substrait again
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.util.EmptyVisitationContext

import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}

class TypesAndLiteralsSuite extends SparkFunSuite {
Expand Down Expand Up @@ -101,7 +103,9 @@ class TypesAndLiteralsSuite extends SparkFunSuite {
l => {
test(s"test literal: $l (${l.dataType})") {
val substraitLiteral = ToSubstraitLiteral.convert(l).get
val sparkLiteral = substraitLiteral.accept(toSparkExpression, null).asInstanceOf[Literal]
val sparkLiteral = substraitLiteral
.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)
.asInstanceOf[Literal]

println("Before: " + l + " " + l.dataType)
println("After: " + sparkLiteral + " " + sparkLiteral.dataType)
Expand All @@ -118,7 +122,9 @@ class TypesAndLiteralsSuite extends SparkFunSuite {
MapType(IntegerType, StringType, valueContainsNull = false))

val substraitLiteral = ToSubstraitLiteral.convert(l).get
val sparkLiteral = substraitLiteral.accept(toSparkExpression, null).asInstanceOf[Literal]
val sparkLiteral = substraitLiteral
.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)
.asInstanceOf[Literal]

println("Before: " + l + " " + l.dataType)
println("After: " + sparkLiteral + " " + sparkLiteral.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.substrait.spark.SparkExtension
import org.apache.spark.sql.catalyst.expressions.Expression

import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.EmptyVisitationContext
import org.scalatest.Assertions.assertResult

trait SubstraitExpressionTestBase {
Expand Down Expand Up @@ -48,7 +49,8 @@ trait SubstraitExpressionTestBase {
f(substraitExp)

if (bidirectional) {
val convertedExpression = substraitExp.accept(toSparkExpression, null)
val convertedExpression =
substraitExp.accept(toSparkExpression, EmptyVisitationContext.INSTANCE)
assertResult(expression)(convertedExpression)
}
}
Expand Down
Loading