diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index a51a53de9..146243e18 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -47,7 +47,7 @@ public TypeProtoConverter getTypeProtoConverter() { } public io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) { - return expression.accept(this, null); + return expression.accept(this, EmptyVisitationContext.INSTANCE); } public List toProto( diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index bec48ec7a..00c340b53 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -6,6 +6,7 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.relation.AggregateFunctionProtoConverter; import io.substrait.type.proto.TypeProtoConverter; +import io.substrait.util.EmptyVisitationContext; /** * Converts from {@link io.substrait.extendedexpression.ExtendedExpression} to {@link @@ -27,7 +28,7 @@ public ExtendedExpression toProto( if (expressionReference instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) { io.substrait.proto.Expression expressionProto = - et.getExpression().accept(expressionProtoConverter, null); + et.getExpression().accept(expressionProtoConverter, EmptyVisitationContext.INSTANCE); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java index b53bfc90a..7a9d0f569 100644 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java @@ -5,6 +5,7 @@ import io.substrait.extension.ExtensionCollector; import io.substrait.proto.AggregateFunction; import io.substrait.type.proto.TypeProtoConverter; +import io.substrait.util.EmptyVisitationContext; import java.util.stream.IntStream; /** @@ -34,7 +35,10 @@ public AggregateFunction toProto(Aggregate.Measure measure) { .setOutputType(measure.getFunction().getType().accept(typeProtoConverter)) .addAllArguments( IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null)) + .mapToObj( + i -> + args.get(i) + .accept(aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) .collect(java.util.stream.Collectors.toList())) .setFunctionReference( functionCollector.getFunctionReference(measure.getFunction().declaration())) diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index c746c48d3..f75d2ab9d 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -77,7 +77,7 @@ public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) { } public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) { - return rel.accept(this, null); + return rel.accept(this, EmptyVisitationContext.INSTANCE); } protected io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) { @@ -136,7 +136,10 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .setOutputType(toProto(measure.getFunction().getType())) .addAllArguments( IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null)) + .mapToObj( + i -> + args.get(i) + .accept(aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) .collect(Collectors.toList())) .addAllSorts(toProtoS(measure.getFunction().sort())) .setFunctionReference( @@ -463,7 +466,11 @@ private List toProtoWindowRelFun var arguments = IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null)) + .mapToObj( + i -> + args.get(i) + .accept( + aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) .collect(Collectors.toList()); var options = f.options().stream() diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java index f69f5d2fa..cfcdaf6fc 100644 --- a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -10,6 +10,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.util.EmptyVisitationContext; import java.util.Arrays; import org.junit.jupiter.api.Test; @@ -26,7 +27,7 @@ void ifThenNotNullable() { var to = new ExpressionProtoConverter(null, null); var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(ifRel, from.from(ifRel.accept(to, null))); + assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } @Test @@ -40,6 +41,6 @@ void ifThenNullable() { var to = new ExpressionProtoConverter(null, null); var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(ifRel, from.from(ifRel.accept(to, null))); + assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index 5e1d58cf1..b0f1b5fe3 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -7,6 +7,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.util.EmptyVisitationContext; import java.math.BigDecimal; import org.junit.jupiter.api.Test; @@ -17,6 +18,6 @@ void decimal() { var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); var to = new ExpressionProtoConverter(null, null); var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to, null))); + assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java index 2f6b2d6bd..c42a81717 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java @@ -79,7 +79,7 @@ private TableGatherer() { */ public static Map, NamedStruct> gatherTables(Rel rootRel) { var visitor = new TableGatherer(); - rootRel.accept(visitor, null); + rootRel.accept(visitor, EmptyVisitationContext.INSTANCE); return visitor.tableMap; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index 3e1e1f383..375f3fb9d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -178,7 +178,7 @@ private NamedStructGatherer() { public static Map, NamedStruct> gatherTables(Rel rel) { var visitor = new NamedStructGatherer(); - rel.accept(visitor, null); + rel.accept(visitor, EmptyVisitationContext.INSTANCE); return visitor.tableMap; } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index a80104a6d..48d13612f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -87,7 +87,10 @@ public Optional visit(Cross cross, EmptyVisitationContext context) "orders" o """, new SqlToSubstrait(featureBoard)); - plan1.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor, null)); + plan1 + .getRoots() + .forEach( + t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE)); assertEquals(1, counter[0]); Plan plan2 = @@ -101,7 +104,10 @@ public Optional visit(Cross cross, EmptyVisitationContext context) "orders" o """, new SqlToSubstrait(featureBoard)); - plan2.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor, null)); + plan2 + .getRoots() + .forEach( + t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE)); assertEquals(2, counter[0]); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index a11173a8a..50994df41 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -156,7 +156,8 @@ public void replaceCountDistinctsInUnion() throws IOException, SqlParseException private static class HasTableReference { public boolean hasTableReference(Plan plan, String name) { HasTableReferenceVisitor visitor = new HasTableReferenceVisitor(Arrays.asList(name)); - plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null)); + plan.getRoots().stream() + .forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE)); return (visitor.hasTableReference()); } @@ -190,7 +191,8 @@ private static class CountCountDistinct { public int getCountDistincts(Plan plan) { CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); - plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null)); + plan.getRoots().stream() + .forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE)); return visitor.getCountDistincts(); } @@ -221,7 +223,8 @@ private static class CountApproxCountDistinct { public int getApproxCountDistincts(Plan plan) { CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); - plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null)); + plan.getRoots().stream() + .forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE)); return visitor.getApproxCountDistincts(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index db9eeca9b..b1ff2d46c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -18,6 +18,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; +import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -62,7 +63,8 @@ void extensionMultiRelDetailTest() { void roundtrip(Rel pojo1) { // Substrait POJO 1 -> Substrait Proto io.substrait.proto.Rel proto = - pojo1.accept(new RelProtoConverter(new ExtensionCollector()), null); + pojo1.accept( + new RelProtoConverter(new ExtensionCollector()), EmptyVisitationContext.INSTANCE); // Substrait Proto -> Substrait POJO 2 var pojo2 = (new CustomProtoRelConverter(new ExtensionCollector())).from(proto); diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 65109f4b4..cf6b7b72c 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -42,7 +42,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { } def apply(rel: Rel, maxFields: Int): String = { - rel.accept(this, null) + rel.accept(this, EmptyVisitationContext.INSTANCE) } override def visit(fetch: Fetch, context: EmptyVisitationContext): String = { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 63ba4c0e7..3630a6e70 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -65,7 +65,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) val function = measure.getFunction var arguments = function.arguments().asScala.zipWithIndex.map { case (arg, i) => - arg.accept(function.declaration(), i, expressionConverter, null) + arg.accept(function.declaration(), i, expressionConverter, EmptyVisitationContext.INSTANCE) } if (function.declaration.name == "count" && function.arguments.size == 0) { // HACK - count() needs to be rewritten as count(1) @@ -92,7 +92,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) }) val filter = Option(measure.getPreMeasureFilter.orElse(null)) - .map(_.accept(expressionConverter, null)) + .map(_.accept(expressionConverter, EmptyVisitationContext.INSTANCE)) AggregateExpression( aggregateFunction, @@ -213,7 +213,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } private def toSortOrder(sortField: SExpression.SortField): SortOrder = { - val expression = sortField.expr().accept(expressionConverter, null) + val expression = sortField.expr().accept(expressionConverter, EmptyVisitationContext.INSTANCE) val (direction, nullOrdering) = sortField.direction() match { case SExpression.SortDirection.ASC_NULLS_FIRST => (Ascending, NullsFirst) case SExpression.SortDirection.DESC_NULLS_FIRST => (Descending, NullsFirst) @@ -449,7 +449,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } def convert(rel: relation.Rel): LogicalPlan = { - val logicalPlan = rel.accept(this, null) + val logicalPlan = rel.accept(this, EmptyVisitationContext.INSTANCE) require(logicalPlan.resolved) logicalPlan } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index ab6d9efef..7210cacc5 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -43,6 +43,7 @@ import io.substrait.relation.RelProtoConverter import io.substrait.relation.Set.SetOp import io.substrait.relation.files.{FileFormat, FileOrFiles} import io.substrait.relation.files.FileOrFiles.PathType +import io.substrait.util.EmptyVisitationContext import io.substrait.utils.Util import java.util @@ -575,7 +576,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { proto.PlanRel .newBuilder() .setRel(substraitRel - .accept(relProtoConverter, null)) + .accept(relProtoConverter, EmptyVisitationContext.INSTANCE)) ) extensionCollector.addExtensionsToPlan(builder) builder.build().toByteArray diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index 491c05fa1..e37d2517b 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -29,6 +29,7 @@ import io.substrait.extension.ExtensionCollector import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter} import io.substrait.proto import io.substrait.relation.{ProtoRelConverter, RelProtoConverter} +import io.substrait.util.EmptyVisitationContext import org.scalactic.Equality import org.scalactic.source.Position import org.scalatest.Succeeded @@ -72,7 +73,7 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => // convert substrait back to spark plan val toLogicalPlan = new ToLogicalPlan(spark); - val sparkPlan2 = substraitRel2.accept(toLogicalPlan, null) + val sparkPlan2 = substraitRel2.accept(toLogicalPlan, EmptyVisitationContext.INSTANCE) require(sparkPlan2.resolved) // and back to substrait again diff --git a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala index e649487cc..90061a167 100644 --- a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala +++ b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala @@ -10,6 +10,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.substrait.SparkTypeUtil import org.apache.spark.unsafe.types.UTF8String +import io.substrait.util.EmptyVisitationContext + import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} class TypesAndLiteralsSuite extends SparkFunSuite { @@ -101,7 +103,9 @@ class TypesAndLiteralsSuite extends SparkFunSuite { l => { test(s"test literal: $l (${l.dataType})") { val substraitLiteral = ToSubstraitLiteral.convert(l).get - val sparkLiteral = substraitLiteral.accept(toSparkExpression, null).asInstanceOf[Literal] + val sparkLiteral = substraitLiteral + .accept(toSparkExpression, EmptyVisitationContext.INSTANCE) + .asInstanceOf[Literal] println("Before: " + l + " " + l.dataType) println("After: " + sparkLiteral + " " + sparkLiteral.dataType) @@ -118,7 +122,9 @@ class TypesAndLiteralsSuite extends SparkFunSuite { MapType(IntegerType, StringType, valueContainsNull = false)) val substraitLiteral = ToSubstraitLiteral.convert(l).get - val sparkLiteral = substraitLiteral.accept(toSparkExpression, null).asInstanceOf[Literal] + val sparkLiteral = substraitLiteral + .accept(toSparkExpression, EmptyVisitationContext.INSTANCE) + .asInstanceOf[Literal] println("Before: " + l + " " + l.dataType) println("After: " + sparkLiteral + " " + sparkLiteral.dataType) diff --git a/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala index fa0381b1b..76bda680a 100644 --- a/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala @@ -21,6 +21,7 @@ import io.substrait.spark.SparkExtension import org.apache.spark.sql.catalyst.expressions.Expression import io.substrait.expression.{Expression => SExpression} +import io.substrait.util.EmptyVisitationContext import org.scalatest.Assertions.assertResult trait SubstraitExpressionTestBase { @@ -48,7 +49,8 @@ trait SubstraitExpressionTestBase { f(substraitExp) if (bidirectional) { - val convertedExpression = substraitExp.accept(toSparkExpression, null) + val convertedExpression = + substraitExp.accept(toSparkExpression, EmptyVisitationContext.INSTANCE) assertResult(expression)(convertedExpression) } }