From eb5c57522ef87338b8e78f4ce0f2d38aa0b2e2f9 Mon Sep 17 00:00:00 2001 From: "Mark S. Lewis" Date: Wed, 9 Jul 2025 18:43:15 +0100 Subject: [PATCH 1/2] chore: avoid Java version mismatch When developing, the build (both on the command-line and in the IDE) frequently fails due to usage of unsupported Java language features for the target Java runtime. This is due to the the Java source version being set to a newer version than the target Java version. While this sometimes works, this combination is explicitly disallowed by the Java compiler. Source version must be equal to or less than the target version. This change uses the correct Java toolchain version for each sub-project. The exception being the core sub-project, since it targets Java 8 whereas ANTLR requires Java 11+ to run. Instead, the Java compiler release option is set to Java 8, ensuring that Java 8 bytecode is produced and that no APIs not present in Java 8 are used. While numerous code changes are required to adhere to the target Java language version, there are no functional changes. Signed-off-by: Mark S. Lewis --- build.gradle.kts | 14 +- core/build.gradle.kts | 12 +- .../io/substrait/dsl/SubstraitBuilder.java | 68 +- .../io/substrait/expression/Expression.java | 12 +- .../expression/ExpressionCreator.java | 9 +- .../substrait/expression/FieldReference.java | 2 +- .../io/substrait/expression/FunctionArg.java | 36 +- .../proto/ExpressionProtoConverter.java | 56 +- .../proto/ProtoExpressionConverter.java | 667 ++++++++++-------- .../ExtendedExpressionProtoConverter.java | 11 +- .../extension/AbstractExtensionLookup.java | 8 +- .../extension/AdvancedExtension.java | 3 +- .../extension/ExtensionCollector.java | 38 +- .../extension/ImmutableExtensionLookup.java | 6 +- .../substrait/extension/SimpleExtension.java | 29 +- .../src/main/java/io/substrait/hint/Hint.java | 3 +- .../io/substrait/relation/AbstractDdlRel.java | 4 +- .../substrait/relation/AbstractWriteRel.java | 6 +- .../AggregateFunctionProtoConverter.java | 13 +- .../java/io/substrait/relation/Expand.java | 4 +- .../ExpressionCopyOnWriteVisitor.java | 42 +- .../main/java/io/substrait/relation/Join.java | 68 +- .../substrait/relation/ProtoRelConverter.java | 310 ++++---- .../relation/RelCopyOnWriteVisitor.java | 107 +-- .../substrait/relation/RelProtoConverter.java | 86 ++- .../main/java/io/substrait/relation/Set.java | 28 +- .../substrait/relation/VirtualTableScan.java | 4 +- .../substrait/relation/files/FileOrFiles.java | 45 +- .../substrait/relation/physical/HashJoin.java | 44 +- .../relation/physical/MergeJoin.java | 44 +- .../relation/physical/NestedLoopJoin.java | 44 +- .../java/io/substrait/type/Deserializers.java | 2 +- .../java/io/substrait/type/NamedStruct.java | 4 +- .../main/java/io/substrait/type/YamlRead.java | 3 +- .../io/substrait/type/parser/ParseToPojo.java | 112 +-- .../type/parser/TypeStringParser.java | 6 +- .../type/proto/BaseProtoConverter.java | 2 +- .../proto/ParameterizedProtoConverter.java | 109 +-- .../type/proto/ProtoTypeConverter.java | 132 ++-- .../proto/TypeExpressionProtoVisitor.java | 158 +++-- .../type/proto/TypeProtoConverter.java | 110 +-- .../extension/TypeExtensionTest.java | 8 +- .../relation/ProtoRelConverterTest.java | 4 +- .../type/proto/AggregateRoundtripTest.java | 20 +- ...istentPartitionWindowRelRoundtripTest.java | 6 +- .../type/proto/ExtensionRoundtripTest.java | 6 +- .../type/proto/GenericRoundtripTest.java | 4 +- .../type/proto/IfThenRoundtripTest.java | 10 +- .../type/proto/LiteralRoundtripTest.java | 8 +- .../type/proto/LocalFilesRoundtripTest.java | 74 +- .../type/proto/ReadRelRoundtripTest.java | 10 +- .../type/proto/TestTypeRoundtrip.java | 2 +- .../isthmus/cli/IsthmusEntryPoint.java | 11 +- isthmus/build.gradle.kts | 4 +- .../substrait/isthmus/AggregateFunctions.java | 3 +- .../isthmus/PreCalciteAggregateValidator.java | 13 +- .../io/substrait/isthmus/RelNodeVisitor.java | 64 +- .../isthmus/SqlExpressionToSubstrait.java | 32 +- .../isthmus/SubstraitRelNodeConverter.java | 155 ++-- .../isthmus/SubstraitRelVisitor.java | 93 ++- .../io/substrait/isthmus/TypeConverter.java | 154 ++-- .../AggregateFunctionConverter.java | 2 +- .../isthmus/expression/CallConverters.java | 7 +- .../isthmus/expression/EnumConverter.java | 30 +- .../expression/ExpressionRexConverter.java | 87 ++- .../expression/FieldSelectionConverter.java | 16 +- .../isthmus/expression/FunctionConverter.java | 25 +- .../isthmus/expression/LiteralConverter.java | 231 +++--- .../expression/RexExpressionConverter.java | 41 +- .../expression/ScalarFunctionConverter.java | 2 +- .../expression/SortFieldConverter.java | 37 +- .../expression/WindowBoundConverter.java | 27 +- .../WindowRelFunctionConverter.java | 2 +- .../sql/SubstraitCreateStatementParser.java | 8 +- .../isthmus/AggregationFunctionsTest.java | 23 +- .../substrait/isthmus/ApplyJoinPlanTest.java | 39 +- .../io/substrait/isthmus/CalciteObjs.java | 16 +- .../io/substrait/isthmus/ComplexSortTest.java | 56 +- .../substrait/isthmus/NameRoundtripTest.java | 4 +- .../isthmus/NestedStructQueryTest.java | 289 ++++---- .../isthmus/ProtoPlanConverterTest.java | 38 +- .../io/substrait/isthmus/utils/SetUtils.java | 31 +- spark/build.gradle.kts | 3 +- 83 files changed, 2315 insertions(+), 1841 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 667ffb892..ec1b80873 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -26,8 +26,6 @@ dependencies { implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}") - annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } val submodulesUpdate by @@ -41,20 +39,10 @@ allprojects { repositories { mavenCentral() } tasks.configureEach { - val javaToolchains = project.extensions.getByType() useJUnitPlatform() - javaLauncher.set(javaToolchains.launcherFor { languageVersion.set(JavaLanguageVersion.of(11)) }) testLogging { exceptionFormat = TestExceptionFormat.FULL } } - tasks.withType { - sourceCompatibility = "17" - if (project.name != "core") { - options.release.set(11) - } else { - options.release.set(8) - } - dependsOn(submodulesUpdate) - } + tasks.withType { dependsOn(submodulesUpdate) } group = "io.substrait" version = "${version}" diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 97b30a3f1..150bf8b2b 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -128,8 +128,6 @@ dependencies { implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}") - annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } configurations[JavaPlugin.API_CONFIGURATION_NAME].let { apiConfiguration -> @@ -233,12 +231,12 @@ tasks { jar { manifest { from("build/generated/sources/manifest/META-INF/MANIFEST.MF") } } } +// Set the release instead of using a Java 8 toolchain since ANTLR requires Java 11+ to run +tasks.withType().configureEach { options.release = 8 } + java { - toolchain { - languageVersion.set(JavaLanguageVersion.of(17)) - withJavadocJar() - withSourcesJar() - } + withJavadocJar() + withSourcesJar() } configurations { runtimeClasspath { resolutionStrategy.activateDependencyLocking() } } diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index f2985ff4a..9de3a0faf 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -1,6 +1,5 @@ package io.substrait.dsl; -import com.github.bsideup.jabel.Desugar; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.Expression.Cast; @@ -87,8 +86,8 @@ private Aggregate aggregate( Function> measuresFn, Optional remap, Rel input) { - var groupings = groupingsFn.apply(input); - var measures = measuresFn.apply(input); + List groupings = groupingsFn.apply(input); + List measures = measuresFn.apply(input); return Aggregate.builder() .groupings(groupings) .measures(measures) @@ -147,12 +146,27 @@ public Filter filter(Function conditionFn, Rel.Remap remap, Rel private Filter filter( Function conditionFn, Optional remap, Rel input) { - var condition = conditionFn.apply(input); + Expression condition = conditionFn.apply(input); return Filter.builder().input(input).condition(condition).remap(remap).build(); } - @Desugar - public record JoinInput(Rel left, Rel right) {} + public static final class JoinInput { + private final Rel left; + private final Rel right; + + JoinInput(Rel left, Rel right) { + this.left = left; + this.right = right; + } + + public Rel left() { + return left; + } + + public Rel right() { + return right; + } + } public Join innerJoin(Function conditionFn, Rel left, Rel right) { return join(conditionFn, Join.JoinType.INNER, left, right); @@ -183,7 +197,7 @@ private Join join( Optional remap, Rel left, Rel right) { - var condition = conditionFn.apply(new JoinInput(left, right)); + Expression condition = conditionFn.apply(new JoinInput(left, right)); return Join.builder() .left(left) .right(right) @@ -263,7 +277,7 @@ private NestedLoopJoin nestedLoopJoin( Optional remap, Rel left, Rel right) { - var condition = conditionFn.apply(new JoinInput(left, right)); + Expression condition = conditionFn.apply(new JoinInput(left, right)); return NestedLoopJoin.builder() .left(left) .right(right) @@ -291,8 +305,8 @@ private NamedScan namedScan( Iterable columnNames, Iterable types, Optional remap) { - var struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); - var namedStruct = NamedStruct.of(columnNames, struct); + Type.Struct struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); + NamedStruct namedStruct = NamedStruct.of(columnNames, struct); return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } @@ -315,7 +329,7 @@ private Project project( Function> expressionsFn, Optional remap, Rel input) { - var expressions = expressionsFn.apply(input); + Iterable expressions = expressionsFn.apply(input); return Project.builder().input(input).expressions(expressions).remap(remap).build(); } @@ -332,7 +346,7 @@ private Expand expand( Function> fieldsFn, Optional remap, Rel input) { - var fields = fieldsFn.apply(input); + Iterable fields = fieldsFn.apply(input); return Expand.builder().input(input).fields(fields).remap(remap).build(); } @@ -363,7 +377,7 @@ private Sort sort( Function> sortFieldFn, Optional remap, Rel input) { - var condition = sortFieldFn.apply(input); + Iterable condition = sortFieldFn.apply(input); return Sort.builder().input(input).sortFields(condition).remap(remap).build(); } @@ -465,7 +479,7 @@ public Switch switchExpression( public AggregateFunctionInvocation aggregateFn( String namespace, String key, Type outputType, Expression... args) { - var declaration = + SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); return AggregateFunctionInvocation.builder() .arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList())) @@ -477,7 +491,7 @@ public AggregateFunctionInvocation aggregateFn( } public Aggregate.Grouping grouping(Rel input, int... indexes) { - var columns = fieldReferences(input, indexes); + List columns = fieldReferences(input, indexes); return Aggregate.Grouping.builder().addAllExpressions(columns).build(); } @@ -486,7 +500,7 @@ public Aggregate.Grouping grouping(Expression... expressions) { } public Aggregate.Measure count(Rel input, int field) { - var declaration = + SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any")); @@ -563,7 +577,7 @@ public Aggregate.Measure sum0(Expression expr) { private Aggregate.Measure singleArgumentArithmeticAggregate( Expression expr, String functionName, Type outputType) { String typeString = ToTypeString.apply(expr.getType()); - var declaration = + SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, @@ -585,7 +599,7 @@ private Aggregate.Measure singleArgumentArithmeticAggregate( public Expression.ScalarFunctionInvocation negate(Expression expr) { // output type of negate is the same as the input type - var outputType = expr.getType(); + Type outputType = expr.getType(); return scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("negate:%s", ToTypeString.apply(outputType)), @@ -611,12 +625,12 @@ public Expression.ScalarFunctionInvocation divide(Expression left, Expression ri private Expression.ScalarFunctionInvocation arithmeticFunction( String fname, Expression left, Expression right) { - var leftTypeStr = ToTypeString.apply(left.getType()); - var rightTypeStr = ToTypeString.apply(right.getType()); - var key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); + String leftTypeStr = ToTypeString.apply(left.getType()); + String rightTypeStr = ToTypeString.apply(right.getType()); + String key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); - var isOutputNullable = left.getType().nullable() || right.getType().nullable(); - var outputType = left.getType(); + boolean isOutputNullable = left.getType().nullable() || right.getType().nullable(); + Type outputType = left.getType(); outputType = isOutputNullable ? TypeCreator.asNullable(outputType) @@ -633,14 +647,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig public Expression.ScalarFunctionInvocation or(Expression... args) { // If any arg is nullable, the output of or is potentially nullable // For example: false or null = null - var isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); - var outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; + boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); + Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args); } public Expression.ScalarFunctionInvocation scalarFn( String namespace, String key, Type outputType, FunctionArg... args) { - var declaration = + SimpleExtension.ScalarFunctionVariant declaration = extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); return Expression.ScalarFunctionInvocation.builder() .declaration(declaration) @@ -659,7 +673,7 @@ public Expression.WindowFunctionInvocation windowFn( WindowBound lowerBound, WindowBound upperBound, Expression... args) { - var declaration = + SimpleExtension.WindowFunctionVariant declaration = extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); return Expression.WindowFunctionInvocation.builder() .declaration(declaration) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 75e003f53..913a80994 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -833,7 +833,7 @@ public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() { public static WindowBoundsType fromProto( io.substrait.proto.Expression.WindowFunction.BoundsType proto) { - for (var v : values()) { + for (WindowBoundsType v : values()) { if (v.proto == proto) { return v; } @@ -984,7 +984,7 @@ public io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp toProto() public static PredicateOp fromProto( io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) { - for (var v : values()) { + for (PredicateOp v : values()) { if (v.proto == proto) { return v; } @@ -1010,7 +1010,7 @@ public io.substrait.proto.AggregateFunction.AggregationInvocation toProto() { } public static AggregationInvocation fromProto(AggregateFunction.AggregationInvocation proto) { - for (var v : values()) { + for (AggregationInvocation v : values()) { if (v.proto == proto) { return v; } @@ -1041,7 +1041,7 @@ public io.substrait.proto.AggregationPhase toProto() { } public static AggregationPhase fromProto(io.substrait.proto.AggregationPhase proto) { - for (var v : values()) { + for (AggregationPhase v : values()) { if (v.proto == proto) { return v; } @@ -1069,7 +1069,7 @@ public io.substrait.proto.SortField.SortDirection toProto() { } public static SortDirection fromProto(io.substrait.proto.SortField.SortDirection proto) { - for (var v : values()) { + for (SortDirection v : values()) { if (v.proto == proto) { return v; } @@ -1097,7 +1097,7 @@ public io.substrait.proto.Expression.Cast.FailureBehavior toProto() { public static FailureBehavior fromProto( io.substrait.proto.Expression.Cast.FailureBehavior proto) { - for (var v : values()) { + for (FailureBehavior v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 55e71b78d..6480c5bec 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -7,6 +7,7 @@ import io.substrait.type.Type; import io.substrait.util.DecimalUtil; import java.math.BigDecimal; +import java.nio.ByteBuffer; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneOffset; @@ -89,7 +90,7 @@ public static Expression.TimestampLiteral timestamp(boolean nullable, long value */ @Deprecated public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateTime value) { - var epochMicro = + long epochMicro = TimeUnit.SECONDS.toMicros(value.toEpochSecond(ZoneOffset.UTC)) + TimeUnit.NANOSECONDS.toMicros(value.toLocalTime().getNano()); return timestamp(nullable, epochMicro); @@ -127,7 +128,7 @@ public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, long v */ @Deprecated public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, Instant value) { - var epochMicro = + long epochMicro = TimeUnit.SECONDS.toMicros(value.getEpochSecond()) + TimeUnit.NANOSECONDS.toMicros(value.getNano()); return timestampTZ(nullable, epochMicro); @@ -195,7 +196,7 @@ public static Expression.IntervalCompoundLiteral intervalCompound( } public static Expression.UUIDLiteral uuid(boolean nullable, ByteString uuid) { - var bb = uuid.asReadOnlyByteBuffer(); + ByteBuffer bb = uuid.asReadOnlyByteBuffer(); return Expression.UUIDLiteral.builder() .nullable(nullable) .value(new UUID(bb.getLong(), bb.getLong())) @@ -237,7 +238,7 @@ public static Expression.DecimalLiteral decimal( public static Expression.DecimalLiteral decimal( boolean nullable, BigDecimal value, int precision, int scale) { - var twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16); + byte[] twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16); return Expression.DecimalLiteral.builder() .nullable(nullable) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index f2926f473..b5c42fdd7 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -225,7 +225,7 @@ private static FieldReference of( Collections.reverse(segments); for (int i = 0; i < segments.size(); i++) { if (i == 0) { - var last = segments.get(0); + ReferenceSegment last = segments.get(0); reference = struct == null ? last.constructOnExpression(expression) : last.constructOnRoot(struct); } else { diff --git a/core/src/main/java/io/substrait/expression/FunctionArg.java b/core/src/main/java/io/substrait/expression/FunctionArg.java index 495def8ad..95a437fd5 100644 --- a/core/src/main/java/io/substrait/expression/FunctionArg.java +++ b/core/src/main/java/io/substrait/expression/FunctionArg.java @@ -34,13 +34,13 @@ static FuncArgVisitor expressionVisitor) { - return new FuncArgVisitor<>() { + return new FuncArgVisitor() { @Override public FunctionArgument visitExpr( SimpleExtension.Function fnDef, int argIdx, Expression e, EmptyVisitationContext context) throws RuntimeException { - var pE = e.accept(expressionVisitor, context); + io.substrait.proto.Expression pE = e.accept(expressionVisitor, context); return FunctionArgument.newBuilder().setValue(pE).build(); } @@ -48,7 +48,7 @@ public FunctionArgument visitExpr( public FunctionArgument visitType( SimpleExtension.Function fnDef, int argIdx, Type t, EmptyVisitationContext context) throws RuntimeException { - var pTyp = t.accept(typeVisitor); + io.substrait.proto.Type pTyp = t.accept(typeVisitor); return FunctionArgument.newBuilder().setType(pTyp).build(); } @@ -56,7 +56,7 @@ public FunctionArgument visitType( public FunctionArgument visitEnumArg( SimpleExtension.Function fnDef, int argIdx, EnumArg ea, EmptyVisitationContext context) throws RuntimeException { - var enumBldr = FunctionArgument.newBuilder(); + FunctionArgument.Builder enumBldr = FunctionArgument.newBuilder(); if (ea.value().isPresent()) { enumBldr = enumBldr.setEnum(ea.value().get()); @@ -82,18 +82,22 @@ public ProtoFrom( public FunctionArg convert( SimpleExtension.Function funcDef, int argIdx, FunctionArgument fArg) { - return switch (fArg.getArgTypeCase()) { - case TYPE -> protoTypeConverter.from(fArg.getType()); - case VALUE -> protoExprConverter.from(fArg.getValue()); - case ENUM -> { - SimpleExtension.EnumArgument enumArgDef = - (SimpleExtension.EnumArgument) funcDef.args().get(argIdx); - var optionValue = fArg.getEnum(); - yield EnumArg.of(enumArgDef, optionValue); - } - default -> throw new UnsupportedOperationException( - String.format("Unable to convert FunctionArgument %s.", fArg)); - }; + switch (fArg.getArgTypeCase()) { + case TYPE: + return protoTypeConverter.from(fArg.getType()); + case VALUE: + return protoExprConverter.from(fArg.getValue()); + case ENUM: + { + SimpleExtension.EnumArgument enumArgDef = + (SimpleExtension.EnumArgument) funcDef.args().get(argIdx); + String optionValue = fArg.getEnum(); + return EnumArg.of(enumArgDef, optionValue); + } + default: + throw new UnsupportedOperationException( + String.format("Unable to convert FunctionArgument %s.", fArg)); + } } } } diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 146243e18..0e7593d27 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -70,7 +70,7 @@ public Expression visit( } private Expression lit(Consumer consumer) { - var builder = Expression.Literal.newBuilder(); + Expression.Literal.Builder builder = Expression.Literal.newBuilder(); consumer.accept(builder); return Expression.newBuilder().setLiteral(builder).build(); } @@ -278,12 +278,12 @@ public Expression visit( io.substrait.expression.Expression.MapLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var keyValues = + List keyValues = expr.values().entrySet().stream() .map( e -> { - var key = toLiteral(e.getKey()); - var value = toLiteral(e.getValue()); + Expression.Literal key = toLiteral(e.getKey()); + Expression.Literal value = toLiteral(e.getValue()); return Expression.Literal.Map.KeyValue.newBuilder() .setKey(key) .setValue(value) @@ -300,7 +300,7 @@ public Expression visit( io.substrait.expression.Expression.EmptyMapLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var protoMapType = toProto(expr.getType()); + Type protoMapType = toProto(expr.getType()); bldr.setEmptyMap(protoMapType.getMap()) // For empty maps, the Literal message's own nullable field should be ignored // in favor of the nullability of the Type.Map in the literal's @@ -316,7 +316,7 @@ public Expression visit( io.substrait.expression.Expression.ListLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var values = + List values = expr.values().stream() .map(this::toLiteral) .collect(java.util.stream.Collectors.toList()); @@ -331,7 +331,7 @@ public Expression visit( throws RuntimeException { return lit( builder -> { - var protoListType = toProto(expr.getType()); + Type protoListType = toProto(expr.getType()); builder .setEmptyList(protoListType.getList()) // For empty lists, the Literal message's own nullable field should be ignored @@ -348,7 +348,7 @@ public Expression visit( io.substrait.expression.Expression.StructLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var values = + List values = expr.fields().stream() .map(this::toLiteral) .collect(java.util.stream.Collectors.toList()); @@ -360,7 +360,7 @@ public Expression visit( @Override public Expression visit( io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { - var typeReference = + int typeReference = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); return lit( bldr -> { @@ -378,7 +378,7 @@ public Expression visit( } private Expression.Literal toLiteral(io.substrait.expression.Expression expression) { - var e = toProto(expression); + Expression e = toProto(expression); assert e.getRexTypeCase() == Expression.RexTypeCase.LITERAL; return e.getLiteral(); } @@ -386,7 +386,7 @@ private Expression.Literal toLiteral(io.substrait.expression.Expression expressi @Override public Expression visit( io.substrait.expression.Expression.Switch expr, EmptyVisitationContext context) { - var clauses = + List clauses = expr.switchClauses().stream() .map( s -> @@ -407,7 +407,7 @@ public Expression visit( @Override public Expression visit( io.substrait.expression.Expression.IfThen expr, EmptyVisitationContext context) { - var clauses = + List clauses = expr.ifClauses().stream() .map( s -> @@ -427,7 +427,8 @@ public Expression visit( io.substrait.expression.Expression.ScalarFunctionInvocation expr, EmptyVisitationContext context) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, this); + FunctionArg.FuncArgVisitor + argVisitor = FunctionArg.toProto(typeProtoConverter, this); return Expression.newBuilder() .setScalarFunction( @@ -499,22 +500,28 @@ public Expression visit( public Expression visit(FieldReference expr, EmptyVisitationContext context) { Expression.ReferenceSegment seg = null; - for (var segment : expr.segments()) { + for (FieldReference.ReferenceSegment segment : expr.segments()) { Expression.ReferenceSegment.Builder protoSegment; - if (segment instanceof FieldReference.StructField f) { - var bldr = Expression.ReferenceSegment.StructField.newBuilder().setField(f.offset()); + if (segment instanceof FieldReference.StructField) { + FieldReference.StructField f = (FieldReference.StructField) segment; + Expression.ReferenceSegment.StructField.Builder bldr = + Expression.ReferenceSegment.StructField.newBuilder().setField(f.offset()); if (seg != null) { bldr.setChild(seg); } protoSegment = Expression.ReferenceSegment.newBuilder().setStructField(bldr); - } else if (segment instanceof FieldReference.ListElement f) { - var bldr = Expression.ReferenceSegment.ListElement.newBuilder().setOffset(f.offset()); + } else if (segment instanceof FieldReference.ListElement) { + FieldReference.ListElement f = (FieldReference.ListElement) segment; + Expression.ReferenceSegment.ListElement.Builder bldr = + Expression.ReferenceSegment.ListElement.newBuilder().setOffset(f.offset()); if (seg != null) { bldr.setChild(seg); } protoSegment = Expression.ReferenceSegment.newBuilder().setListElement(bldr); - } else if (segment instanceof FieldReference.MapKey f) { - var bldr = Expression.ReferenceSegment.MapKey.newBuilder().setMapKey(toLiteral(f.key())); + } else if (segment instanceof FieldReference.MapKey) { + FieldReference.MapKey f = (FieldReference.MapKey) segment; + Expression.ReferenceSegment.MapKey.Builder bldr = + Expression.ReferenceSegment.MapKey.newBuilder().setMapKey(toLiteral(f.key())); if (seg != null) { bldr.setChild(seg); } @@ -522,11 +529,11 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { } else { throw new IllegalArgumentException("Unhandled type: " + segment); } - var builtSegment = protoSegment.build(); - seg = builtSegment; + seg = protoSegment.build(); } - var out = Expression.FieldReference.newBuilder().setDirectReference(seg); + Expression.FieldReference.Builder out = + Expression.FieldReference.newBuilder().setDirectReference(seg); if (expr.inputExpression().isPresent()) { out.setExpression(toProto(expr.inputExpression().get())); @@ -591,7 +598,8 @@ public Expression visit( io.substrait.expression.Expression.WindowFunctionInvocation expr, EmptyVisitationContext context) throws RuntimeException { - var argVisitor = FunctionArg.toProto(typeProtoConverter, this); + FunctionArg.FuncArgVisitor + argVisitor = FunctionArg.toProto(typeProtoConverter, this); List args = expr.arguments().stream() .map(a -> a.accept(expr.declaration(), 0, argVisitor, context)) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 4a4d6a4fc..52f780f9a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -3,12 +3,14 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; +import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; import io.substrait.proto.ConsistentPartitionWindowRel; +import io.substrait.proto.Expression.FieldReference.ReferenceTypeCase; import io.substrait.proto.FunctionArgument; import io.substrait.proto.SortField; import io.substrait.relation.ConsistentPartitionWindow; @@ -53,211 +55,241 @@ public ProtoExpressionConverter( } public FieldReference from(io.substrait.proto.Expression.FieldReference reference) { - switch (reference.getReferenceTypeCase()) { - case DIRECT_REFERENCE -> { - io.substrait.proto.Expression.ReferenceSegment segment = reference.getDirectReference(); + io.substrait.proto.Expression.FieldReference.ReferenceTypeCase refTypeCase = + reference.getReferenceTypeCase(); - var segments = new ArrayList(); - while (segment != io.substrait.proto.Expression.ReferenceSegment.getDefaultInstance()) { - segments.add( - switch (segment.getReferenceTypeCase()) { - case MAP_KEY -> { - var mapKey = segment.getMapKey(); - segment = mapKey.getChild(); - yield FieldReference.MapKey.of(from(mapKey.getMapKey())); - } - case STRUCT_FIELD -> { - var structField = segment.getStructField(); - segment = structField.getChild(); - yield FieldReference.StructField.of(structField.getField()); - } - case LIST_ELEMENT -> { - var listElement = segment.getListElement(); - segment = listElement.getChild(); - yield FieldReference.ListElement.of(listElement.getOffset()); - } - case REFERENCETYPE_NOT_SET -> throw new IllegalArgumentException( - "Unhandled type: " + segment.getReferenceTypeCase()); - }); - } - Collections.reverse(segments); - var fieldReference = - switch (reference.getRootTypeCase()) { - case EXPRESSION -> FieldReference.ofExpression( - from(reference.getExpression()), segments); - case ROOT_REFERENCE -> FieldReference.ofRoot(rootType, segments); - case OUTER_REFERENCE -> FieldReference.newRootStructOuterReference( - reference.getDirectReference().getStructField().getField(), - rootType, - reference.getOuterReference().getStepsOut()); - case ROOTTYPE_NOT_SET -> throw new IllegalArgumentException( - "Unhandled type: " + reference.getRootTypeCase()); - }; + if (refTypeCase == ReferenceTypeCase.MASKED_REFERENCE) { + throw new IllegalArgumentException("Unsupported type: " + refTypeCase); + } + + if (refTypeCase != ReferenceTypeCase.DIRECT_REFERENCE) { + throw new IllegalArgumentException("Unhandled type: " + refTypeCase); + } + + switch (reference.getRootTypeCase()) { + case EXPRESSION: + return FieldReference.ofExpression( + from(reference.getExpression()), + getDirectReferenceSegments(reference.getDirectReference())); + case ROOT_REFERENCE: + return FieldReference.ofRoot( + rootType, getDirectReferenceSegments(reference.getDirectReference())); + case OUTER_REFERENCE: + return FieldReference.newRootStructOuterReference( + reference.getDirectReference().getStructField().getField(), + rootType, + reference.getOuterReference().getStepsOut()); + case ROOTTYPE_NOT_SET: + default: + throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); + } + } + + private List getDirectReferenceSegments( + io.substrait.proto.Expression.ReferenceSegment segment) { + List results = new ArrayList<>(); - return fieldReference; + while (segment != io.substrait.proto.Expression.ReferenceSegment.getDefaultInstance()) { + final ReferenceSegment mappedSegment; + switch (segment.getReferenceTypeCase()) { + case MAP_KEY: + io.substrait.proto.Expression.ReferenceSegment.MapKey mapKey = segment.getMapKey(); + segment = mapKey.getChild(); + mappedSegment = FieldReference.MapKey.of(from(mapKey.getMapKey())); + break; + case STRUCT_FIELD: + io.substrait.proto.Expression.ReferenceSegment.StructField structField = + segment.getStructField(); + segment = structField.getChild(); + mappedSegment = FieldReference.StructField.of(structField.getField()); + break; + case LIST_ELEMENT: + io.substrait.proto.Expression.ReferenceSegment.ListElement listElement = + segment.getListElement(); + segment = listElement.getChild(); + mappedSegment = FieldReference.ListElement.of(listElement.getOffset()); + break; + case REFERENCETYPE_NOT_SET: + default: + throw new IllegalArgumentException("Unhandled type: " + segment.getReferenceTypeCase()); } - case MASKED_REFERENCE -> throw new IllegalArgumentException( - "Unsupported type: " + reference.getReferenceTypeCase()); - default -> throw new IllegalArgumentException( - "Unhandled type: " + reference.getReferenceTypeCase()); + + results.add(mappedSegment); } + + Collections.reverse(results); + + return results; } public Expression from(io.substrait.proto.Expression expr) { - return switch (expr.getRexTypeCase()) { - case LITERAL -> from(expr.getLiteral()); - case SELECTION -> from(expr.getSelection()); - case SCALAR_FUNCTION -> { - var scalarFunction = expr.getScalarFunction(); - var functionReference = scalarFunction.getFunctionReference(); - var declaration = lookup.getScalarFunction(functionReference, extensions); - var pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); - var args = - IntStream.range(0, scalarFunction.getArgumentsCount()) - .mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) - .collect(java.util.stream.Collectors.toList()); - var options = - scalarFunction.getOptionsList().stream() - .map(ProtoExpressionConverter::fromFunctionOption) - .collect(Collectors.toList()); - yield Expression.ScalarFunctionInvocation.builder() - .addAllArguments(args) - .declaration(declaration) - .outputType(protoTypeConverter.from(scalarFunction.getOutputType())) - .options(options) - .build(); - } - case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction()); - case IF_THEN -> { - var ifThen = expr.getIfThen(); - var clauses = - ifThen.getIfsList().stream() - .map(t -> ExpressionCreator.ifThenClause(from(t.getIf()), from(t.getThen()))) - .collect(java.util.stream.Collectors.toList()); - yield ExpressionCreator.ifThenStatement(from(ifThen.getElse()), clauses); - } - case SWITCH_EXPRESSION -> { - var switchExpr = expr.getSwitchExpression(); - var clauses = - switchExpr.getIfsList().stream() - .map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen()))) - .collect(java.util.stream.Collectors.toList()); - yield ExpressionCreator.switchStatement( - from(switchExpr.getMatch()), from(switchExpr.getElse()), clauses); - } - case SINGULAR_OR_LIST -> { - var orList = expr.getSingularOrList(); - var values = - orList.getOptionsList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList()); - yield Expression.SingleOrList.builder() - .condition(from(orList.getValue())) - .addAllOptions(values) - .build(); - } - case MULTI_OR_LIST -> { - var multiOrList = expr.getMultiOrList(); - var values = - multiOrList.getOptionsList().stream() - .map( - t -> - Expression.MultiOrListRecord.builder() - .addAllValues( - t.getFieldsList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())) - .build()) - .collect(java.util.stream.Collectors.toList()); - yield Expression.MultiOrList.builder() - .addAllOptionCombinations(values) - .addAllConditions( - multiOrList.getValueList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())) - .build(); - } - case CAST -> ExpressionCreator.cast( - protoTypeConverter.from(expr.getCast().getType()), - from(expr.getCast().getInput()), - Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior())); - case SUBQUERY -> { - switch (expr.getSubquery().getSubqueryTypeCase()) { - case SET_PREDICATE -> { - var rel = protoRelConverter.from(expr.getSubquery().getSetPredicate().getTuples()); - yield Expression.SetPredicate.builder() - .tuples(rel) - .predicateOp( - Expression.PredicateOp.fromProto( - expr.getSubquery().getSetPredicate().getPredicateOp())) - .build(); - } - case SCALAR -> { - var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); - yield Expression.ScalarSubquery.builder() - .input(rel) - .type( - rel.getRecordType() - .accept( - new TypeVisitor.TypeThrowsVisitor( - "Expected struct field") { - @Override - public Type visit(Type.Struct type) throws RuntimeException { - if (type.fields().size() != 1) { - throw new UnsupportedOperationException( - "Scalar subquery must have exactly one field"); - } - // Result can be null if the query returns no rows - return type.fields().get(0); - } - })) - .build(); - } - case IN_PREDICATE -> { - var rel = protoRelConverter.from(expr.getSubquery().getInPredicate().getHaystack()); - var needles = - expr.getSubquery().getInPredicate().getNeedlesList().stream() - .map(e -> this.from(e)) - .collect(java.util.stream.Collectors.toList()); - yield Expression.InPredicate.builder().haystack(rel).needles(needles).build(); - } - case SET_COMPARISON -> { - throw new UnsupportedOperationException( - "Unsupported subquery type: " + expr.getSubquery().getSubqueryTypeCase()); - } - default -> { - throw new IllegalArgumentException( - "Unknown subquery type: " + expr.getSubquery().getSubqueryTypeCase()); + switch (expr.getRexTypeCase()) { + case LITERAL: + return from(expr.getLiteral()); + case SELECTION: + return from(expr.getSelection()); + case SCALAR_FUNCTION: + { + io.substrait.proto.Expression.ScalarFunction scalarFunction = expr.getScalarFunction(); + int functionReference = scalarFunction.getFunctionReference(); + SimpleExtension.ScalarFunctionVariant declaration = + lookup.getScalarFunction(functionReference, extensions); + FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); + List args = + IntStream.range(0, scalarFunction.getArgumentsCount()) + .mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) + .collect(Collectors.toList()); + List options = + scalarFunction.getOptionsList().stream() + .map(ProtoExpressionConverter::fromFunctionOption) + .collect(Collectors.toList()); + return Expression.ScalarFunctionInvocation.builder() + .addAllArguments(args) + .declaration(declaration) + .outputType(protoTypeConverter.from(scalarFunction.getOutputType())) + .options(options) + .build(); + } + case WINDOW_FUNCTION: + return fromWindowFunction(expr.getWindowFunction()); + case IF_THEN: + { + io.substrait.proto.Expression.IfThen ifThen = expr.getIfThen(); + List clauses = + ifThen.getIfsList().stream() + .map(t -> ExpressionCreator.ifThenClause(from(t.getIf()), from(t.getThen()))) + .collect(Collectors.toList()); + return ExpressionCreator.ifThenStatement(from(ifThen.getElse()), clauses); + } + case SWITCH_EXPRESSION: + { + io.substrait.proto.Expression.SwitchExpression switchExpr = expr.getSwitchExpression(); + List clauses = + switchExpr.getIfsList().stream() + .map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen()))) + .collect(Collectors.toList()); + return ExpressionCreator.switchStatement( + from(switchExpr.getMatch()), from(switchExpr.getElse()), clauses); + } + case SINGULAR_OR_LIST: + { + io.substrait.proto.Expression.SingularOrList orList = expr.getSingularOrList(); + List values = + orList.getOptionsList().stream().map(this::from).collect(Collectors.toList()); + return Expression.SingleOrList.builder() + .condition(from(orList.getValue())) + .addAllOptions(values) + .build(); + } + case MULTI_OR_LIST: + { + io.substrait.proto.Expression.MultiOrList multiOrList = expr.getMultiOrList(); + List values = + multiOrList.getOptionsList().stream() + .map( + t -> + Expression.MultiOrListRecord.builder() + .addAllValues( + t.getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())) + .build()) + .collect(Collectors.toList()); + return Expression.MultiOrList.builder() + .addAllOptionCombinations(values) + .addAllConditions( + multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList())) + .build(); + } + case CAST: + return ExpressionCreator.cast( + protoTypeConverter.from(expr.getCast().getType()), + from(expr.getCast().getInput()), + Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior())); + case SUBQUERY: + { + switch (expr.getSubquery().getSubqueryTypeCase()) { + case SET_PREDICATE: + { + io.substrait.relation.Rel rel = + protoRelConverter.from(expr.getSubquery().getSetPredicate().getTuples()); + return Expression.SetPredicate.builder() + .tuples(rel) + .predicateOp( + Expression.PredicateOp.fromProto( + expr.getSubquery().getSetPredicate().getPredicateOp())) + .build(); + } + case SCALAR: + { + io.substrait.relation.Rel rel = + protoRelConverter.from(expr.getSubquery().getScalar().getInput()); + return Expression.ScalarSubquery.builder() + .input(rel) + .type( + rel.getRecordType() + .accept( + new TypeVisitor.TypeThrowsVisitor( + "Expected struct field") { + @Override + public Type visit(Type.Struct type) throws RuntimeException { + if (type.fields().size() != 1) { + throw new UnsupportedOperationException( + "Scalar subquery must have exactly one field"); + } + // Result can be null if the query returns no rows + return type.fields().get(0); + } + })) + .build(); + } + case IN_PREDICATE: + { + io.substrait.relation.Rel rel = + protoRelConverter.from(expr.getSubquery().getInPredicate().getHaystack()); + List needles = + expr.getSubquery().getInPredicate().getNeedlesList().stream() + .map(e -> this.from(e)) + .collect(Collectors.toList()); + return Expression.InPredicate.builder().haystack(rel).needles(needles).build(); + } + case SET_COMPARISON: + throw new UnsupportedOperationException( + "Unsupported subquery type: " + expr.getSubquery().getSubqueryTypeCase()); + default: + throw new IllegalArgumentException( + "Unknown subquery type: " + expr.getSubquery().getSubqueryTypeCase()); } } - } // TODO enum. - case ENUM -> throw new UnsupportedOperationException( - "Unsupported type: " + expr.getRexTypeCase()); - default -> throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase()); - }; + case ENUM: + throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); + default: + throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase()); + } } public Expression.WindowFunctionInvocation fromWindowFunction( io.substrait.proto.Expression.WindowFunction windowFunction) { - var functionReference = windowFunction.getFunctionReference(); - var declaration = lookup.getWindowFunction(functionReference, extensions); - var argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); + int functionReference = windowFunction.getFunctionReference(); + SimpleExtension.WindowFunctionVariant declaration = + lookup.getWindowFunction(functionReference, extensions); + FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); - var args = + List args = fromFunctionArgumentList( windowFunction.getArgumentsCount(), argVisitor, declaration, windowFunction::getArguments); - var partitionExprs = + List partitionExprs = windowFunction.getPartitionsList().stream().map(this::from).collect(Collectors.toList()); - var sortFields = + List sortFields = windowFunction.getSortsList().stream() .map(this::fromSortField) .collect(Collectors.toList()); - var options = + List options = windowFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); @@ -282,17 +314,18 @@ public Expression.WindowFunctionInvocation fromWindowFunction( public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFunction( ConsistentPartitionWindowRel.WindowRelFunction windowRelFunction) { - var functionReference = windowRelFunction.getFunctionReference(); - var declaration = lookup.getWindowFunction(functionReference, extensions); - var argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); + int functionReference = windowRelFunction.getFunctionReference(); + SimpleExtension.WindowFunctionVariant declaration = + lookup.getWindowFunction(functionReference, extensions); + FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); - var args = + List args = fromFunctionArgumentList( windowRelFunction.getArgumentsCount(), argVisitor, declaration, windowRelFunction::getArguments); - var options = + List options = windowRelFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); @@ -314,125 +347,163 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti } private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { - return switch (bound.getKindCase()) { - case PRECEDING -> WindowBound.Preceding.of(bound.getPreceding().getOffset()); - case FOLLOWING -> WindowBound.Following.of(bound.getFollowing().getOffset()); - case CURRENT_ROW -> WindowBound.CURRENT_ROW; - case UNBOUNDED -> WindowBound.UNBOUNDED; - case KIND_NOT_SET -> - // per the spec, the lower and upper bounds default to the start or end of the partition - // respectively if not set - WindowBound.UNBOUNDED; - }; + switch (bound.getKindCase()) { + case PRECEDING: + return WindowBound.Preceding.of(bound.getPreceding().getOffset()); + case FOLLOWING: + return WindowBound.Following.of(bound.getFollowing().getOffset()); + case CURRENT_ROW: + return WindowBound.CURRENT_ROW; + case UNBOUNDED: + return WindowBound.UNBOUNDED; + case KIND_NOT_SET: + // per the spec, the lower and upper bounds default to the start or end of the partition + // respectively if not set + return WindowBound.UNBOUNDED; + default: + throw new IllegalArgumentException("Unsupported bound kind: " + bound.getKindCase()); + } } public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { - return switch (literal.getLiteralTypeCase()) { - case BOOLEAN -> ExpressionCreator.bool(literal.getNullable(), literal.getBoolean()); - case I8 -> ExpressionCreator.i8(literal.getNullable(), literal.getI8()); - case I16 -> ExpressionCreator.i16(literal.getNullable(), literal.getI16()); - case I32 -> ExpressionCreator.i32(literal.getNullable(), literal.getI32()); - case I64 -> ExpressionCreator.i64(literal.getNullable(), literal.getI64()); - case FP32 -> ExpressionCreator.fp32(literal.getNullable(), literal.getFp32()); - case FP64 -> ExpressionCreator.fp64(literal.getNullable(), literal.getFp64()); - case STRING -> ExpressionCreator.string(literal.getNullable(), literal.getString()); - case BINARY -> ExpressionCreator.binary(literal.getNullable(), literal.getBinary()); - case TIMESTAMP -> ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp()); - case TIMESTAMP_TZ -> ExpressionCreator.timestampTZ( - literal.getNullable(), literal.getTimestampTz()); - case PRECISION_TIMESTAMP -> ExpressionCreator.precisionTimestamp( - literal.getNullable(), - literal.getPrecisionTimestamp().getValue(), - literal.getPrecisionTimestamp().getPrecision()); - case PRECISION_TIMESTAMP_TZ -> ExpressionCreator.precisionTimestampTZ( - literal.getNullable(), - literal.getPrecisionTimestampTz().getValue(), - literal.getPrecisionTimestampTz().getPrecision()); - case DATE -> ExpressionCreator.date(literal.getNullable(), literal.getDate()); - case TIME -> ExpressionCreator.time(literal.getNullable(), literal.getTime()); - case INTERVAL_YEAR_TO_MONTH -> ExpressionCreator.intervalYear( - literal.getNullable(), - literal.getIntervalYearToMonth().getYears(), - literal.getIntervalYearToMonth().getMonths()); - case INTERVAL_DAY_TO_SECOND -> { - // Handle deprecated version that doesn't provide precision and that uses microseconds - // instead of subseconds, for backwards compatibility - int precision = - literal.getIntervalDayToSecond().hasPrecision() - ? literal.getIntervalDayToSecond().getPrecision() - : 6; // microseconds - long subseconds = - literal.getIntervalDayToSecond().hasPrecision() - ? literal.getIntervalDayToSecond().getSubseconds() - : literal.getIntervalDayToSecond().getMicroseconds(); - yield ExpressionCreator.intervalDay( + switch (literal.getLiteralTypeCase()) { + case BOOLEAN: + return ExpressionCreator.bool(literal.getNullable(), literal.getBoolean()); + case I8: + return ExpressionCreator.i8(literal.getNullable(), literal.getI8()); + case I16: + return ExpressionCreator.i16(literal.getNullable(), literal.getI16()); + case I32: + return ExpressionCreator.i32(literal.getNullable(), literal.getI32()); + case I64: + return ExpressionCreator.i64(literal.getNullable(), literal.getI64()); + case FP32: + return ExpressionCreator.fp32(literal.getNullable(), literal.getFp32()); + case FP64: + return ExpressionCreator.fp64(literal.getNullable(), literal.getFp64()); + case STRING: + return ExpressionCreator.string(literal.getNullable(), literal.getString()); + case BINARY: + return ExpressionCreator.binary(literal.getNullable(), literal.getBinary()); + case TIMESTAMP: + return ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp()); + case TIMESTAMP_TZ: + return ExpressionCreator.timestampTZ(literal.getNullable(), literal.getTimestampTz()); + case PRECISION_TIMESTAMP: + return ExpressionCreator.precisionTimestamp( literal.getNullable(), - literal.getIntervalDayToSecond().getDays(), - literal.getIntervalDayToSecond().getSeconds(), - subseconds, - precision); - } - case INTERVAL_COMPOUND -> { - if (!literal.getIntervalCompound().getIntervalDayToSecond().hasPrecision()) { - throw new RuntimeException( - "Interval compound with deprecated version of interval day (ie. no precision) is not supported"); + literal.getPrecisionTimestamp().getValue(), + literal.getPrecisionTimestamp().getPrecision()); + case PRECISION_TIMESTAMP_TZ: + return ExpressionCreator.precisionTimestampTZ( + literal.getNullable(), + literal.getPrecisionTimestampTz().getValue(), + literal.getPrecisionTimestampTz().getPrecision()); + case DATE: + return ExpressionCreator.date(literal.getNullable(), literal.getDate()); + case TIME: + return ExpressionCreator.time(literal.getNullable(), literal.getTime()); + case INTERVAL_YEAR_TO_MONTH: + return ExpressionCreator.intervalYear( + literal.getNullable(), + literal.getIntervalYearToMonth().getYears(), + literal.getIntervalYearToMonth().getMonths()); + case INTERVAL_DAY_TO_SECOND: + { + // Handle deprecated version that doesn't provide precision and that uses microseconds + // instead of subseconds, for backwards compatibility + int precision = + literal.getIntervalDayToSecond().hasPrecision() + ? literal.getIntervalDayToSecond().getPrecision() + : 6; // microseconds + long subseconds = + literal.getIntervalDayToSecond().hasPrecision() + ? literal.getIntervalDayToSecond().getSubseconds() + : literal.getIntervalDayToSecond().getMicroseconds(); + return ExpressionCreator.intervalDay( + literal.getNullable(), + literal.getIntervalDayToSecond().getDays(), + literal.getIntervalDayToSecond().getSeconds(), + subseconds, + precision); + } + case INTERVAL_COMPOUND: + { + if (!literal.getIntervalCompound().getIntervalDayToSecond().hasPrecision()) { + throw new RuntimeException( + "Interval compound with deprecated version of interval day (ie. no precision) is not supported"); + } + return ExpressionCreator.intervalCompound( + literal.getNullable(), + literal.getIntervalCompound().getIntervalYearToMonth().getYears(), + literal.getIntervalCompound().getIntervalYearToMonth().getMonths(), + literal.getIntervalCompound().getIntervalDayToSecond().getDays(), + literal.getIntervalCompound().getIntervalDayToSecond().getSeconds(), + literal.getIntervalCompound().getIntervalDayToSecond().getSubseconds(), + literal.getIntervalCompound().getIntervalDayToSecond().getPrecision()); } - yield ExpressionCreator.intervalCompound( + case FIXED_CHAR: + return ExpressionCreator.fixedChar(literal.getNullable(), literal.getFixedChar()); + case VAR_CHAR: + return ExpressionCreator.varChar( literal.getNullable(), - literal.getIntervalCompound().getIntervalYearToMonth().getYears(), - literal.getIntervalCompound().getIntervalYearToMonth().getMonths(), - literal.getIntervalCompound().getIntervalDayToSecond().getDays(), - literal.getIntervalCompound().getIntervalDayToSecond().getSeconds(), - literal.getIntervalCompound().getIntervalDayToSecond().getSubseconds(), - literal.getIntervalCompound().getIntervalDayToSecond().getPrecision()); - } - case FIXED_CHAR -> ExpressionCreator.fixedChar(literal.getNullable(), literal.getFixedChar()); - case VAR_CHAR -> ExpressionCreator.varChar( - literal.getNullable(), literal.getVarChar().getValue(), literal.getVarChar().getLength()); - case FIXED_BINARY -> ExpressionCreator.fixedBinary( - literal.getNullable(), literal.getFixedBinary()); - case DECIMAL -> ExpressionCreator.decimal( - literal.getNullable(), - literal.getDecimal().getValue(), - literal.getDecimal().getPrecision(), - literal.getDecimal().getScale()); - case STRUCT -> ExpressionCreator.struct( - literal.getNullable(), - literal.getStruct().getFieldsList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())); - case MAP -> ExpressionCreator.map( - literal.getNullable(), - literal.getMap().getKeyValuesList().stream() - .collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue())))); - case EMPTY_MAP -> { - // literal.getNullable() is intentionally ignored in favor of the nullability - // specified in the literal.getEmptyMap() type. - var mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); - yield ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value()); - } - case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid()); - case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull())); - case LIST -> ExpressionCreator.list( - literal.getNullable(), - literal.getList().getValuesList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())); - case EMPTY_LIST -> { - // literal.getNullable() is intentionally ignored in favor of the nullability - // specified in the literal.getEmptyList() type. - var listType = protoTypeConverter.fromList(literal.getEmptyList()); - yield ExpressionCreator.emptyList(listType.nullable(), listType.elementType()); - } - case USER_DEFINED -> { - var userDefinedLiteral = literal.getUserDefined(); - var type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - yield ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); - } - default -> throw new IllegalStateException( - "Unexpected value: " + literal.getLiteralTypeCase()); - }; + literal.getVarChar().getValue(), + literal.getVarChar().getLength()); + case FIXED_BINARY: + return ExpressionCreator.fixedBinary(literal.getNullable(), literal.getFixedBinary()); + case DECIMAL: + return ExpressionCreator.decimal( + literal.getNullable(), + literal.getDecimal().getValue(), + literal.getDecimal().getPrecision(), + literal.getDecimal().getScale()); + case STRUCT: + return ExpressionCreator.struct( + literal.getNullable(), + literal.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case MAP: + return ExpressionCreator.map( + literal.getNullable(), + literal.getMap().getKeyValuesList().stream() + .collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue())))); + case EMPTY_MAP: + { + // literal.getNullable() is intentionally ignored in favor of the nullability + // specified in the literal.getEmptyMap() type. + Type.Map mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); + return ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value()); + } + case UUID: + return ExpressionCreator.uuid(literal.getNullable(), literal.getUuid()); + case NULL: + return ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull())); + case LIST: + return ExpressionCreator.list( + literal.getNullable(), + literal.getList().getValuesList().stream() + .map(this::from) + .collect(Collectors.toList())); + case EMPTY_LIST: + { + // literal.getNullable() is intentionally ignored in favor of the nullability + // specified in the literal.getEmptyList() type. + Type.ListType listType = protoTypeConverter.fromList(literal.getEmptyList()); + return ExpressionCreator.emptyList(listType.nullable(), listType.elementType()); + } + case USER_DEFINED: + { + io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = + literal.getUserDefined(); + SimpleExtension.Type type = + lookup.getType(userDefinedLiteral.getTypeReference(), extensions); + return ExpressionCreator.userDefinedLiteral( + literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); + } + default: + throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); + } } private static List fromFunctionArgumentList( diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index 00c340b53..e0cb9d1c5 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -26,7 +26,10 @@ public ExtendedExpression toProto( for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReferenceBase expressionReference : extendedExpression.getReferredExpressions()) { if (expressionReference - instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) { + instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference) { + io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et = + (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference) + expressionReference; io.substrait.proto.Expression expressionProto = et.getExpression().accept(expressionProtoConverter, EmptyVisitationContext.INSTANCE); ExpressionReference.Builder expressionReferenceBuilder = @@ -36,8 +39,10 @@ public ExtendedExpression toProto( builder.addReferredExpr(expressionReferenceBuilder); } else if (expressionReference instanceof - io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference - aft) { + io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference) { + io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference aft = + (io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference) + expressionReference; ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setMeasure( diff --git a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java index 51aa38ebd..a2ad68f70 100644 --- a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java @@ -15,7 +15,7 @@ public AbstractExtensionLookup( public SimpleExtension.ScalarFunctionVariant getScalarFunction( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = functionAnchorMap.get(reference); + SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -26,7 +26,7 @@ public SimpleExtension.ScalarFunctionVariant getScalarFunction( public SimpleExtension.WindowFunctionVariant getWindowFunction( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = functionAnchorMap.get(reference); + SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -37,7 +37,7 @@ public SimpleExtension.WindowFunctionVariant getWindowFunction( public SimpleExtension.AggregateFunctionVariant getAggregateFunction( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = functionAnchorMap.get(reference); + SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -48,7 +48,7 @@ public SimpleExtension.AggregateFunctionVariant getAggregateFunction( public SimpleExtension.Type getType( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = typeAnchorMap.get(reference); + SimpleExtension.TypeAnchor anchor = typeAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown type id. Make sure that the type id provided was shared in the extensions section of the plan."); diff --git a/core/src/main/java/io/substrait/extension/AdvancedExtension.java b/core/src/main/java/io/substrait/extension/AdvancedExtension.java index fb5370244..bd717a636 100644 --- a/core/src/main/java/io/substrait/extension/AdvancedExtension.java +++ b/core/src/main/java/io/substrait/extension/AdvancedExtension.java @@ -14,7 +14,8 @@ public abstract class AdvancedExtension { public abstract Optional getEnhancement(); public io.substrait.proto.AdvancedExtension toProto(RelProtoConverter relProtoConverter) { - var builder = io.substrait.proto.AdvancedExtension.newBuilder(); + io.substrait.proto.AdvancedExtension.Builder builder = + io.substrait.proto.AdvancedExtension.newBuilder(); getEnhancement().ifPresent(e -> builder.setEnhancement(e.toProto(relProtoConverter))); getOptimizations().forEach(e -> builder.addOptimization(e.toProto(relProtoConverter))); return builder.build(); diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 5cfb03589..3dc0fe7c1 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,6 +1,5 @@ package io.substrait.extension; -import com.github.bsideup.jabel.Desugar; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; @@ -56,23 +55,23 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris().values()); - builder.addAllExtensions(simpleExtensions.extensionList()); + builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensions(simpleExtensions.extensionList); } public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris().values()); - builder.addAllExtensions(simpleExtensions.extensionList()); + builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensions(simpleExtensions.extensionList); } private SimpleExtensions getExtensions() { - var uriPos = new AtomicInteger(1); - var uris = new HashMap(); + AtomicInteger uriPos = new AtomicInteger(1); + HashMap uris = new HashMap<>(); - var extensionList = new ArrayList(); - for (var e : funcMap.forwardMap.entrySet()) { + ArrayList extensionList = new ArrayList<>(); + for (Map.Entry e : funcMap.forwardMap.entrySet()) { SimpleExtensionURI uri = uris.computeIfAbsent( e.getValue().namespace(), @@ -81,7 +80,7 @@ private SimpleExtensions getExtensions() { .setExtensionUriAnchor(uriPos.getAndIncrement()) .setUri(k) .build()); - var decl = + SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() .setExtensionFunction( SimpleExtensionDeclaration.ExtensionFunction.newBuilder() @@ -91,7 +90,7 @@ private SimpleExtensions getExtensions() { .build(); extensionList.add(decl); } - for (var e : typeMap.forwardMap.entrySet()) { + for (Map.Entry e : typeMap.forwardMap.entrySet()) { SimpleExtensionURI uri = uris.computeIfAbsent( e.getValue().namespace(), @@ -100,7 +99,7 @@ private SimpleExtensions getExtensions() { .setExtensionUriAnchor(uriPos.getAndIncrement()) .setUri(k) .build()); - var decl = + SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() .setExtensionType( SimpleExtensionDeclaration.ExtensionType.newBuilder() @@ -113,10 +112,17 @@ private SimpleExtensions getExtensions() { return new SimpleExtensions(uris, extensionList); } - @Desugar - private record SimpleExtensions( - HashMap uris, - ArrayList extensionList) {} + private static final class SimpleExtensions { + final HashMap uris; + final ArrayList extensionList; + + SimpleExtensions( + HashMap uris, + ArrayList extensionList) { + this.uris = uris; + this.extensionList = extensionList; + } + } /** We don't depend on guava... */ private static class BidiMap { diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 70034d9b1..7e1c81417 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -46,12 +46,12 @@ private Builder from( List simpleExtensionURIs, List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); - for (var extension : simpleExtensionURIs) { + for (SimpleExtensionURI extension : simpleExtensionURIs) { namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (var extension : simpleExtensionDeclarations) { + for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } @@ -68,7 +68,7 @@ private Builder from( } // Add all types used in plan to the typeMap - for (var extension : simpleExtensionDeclarations) { + for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 56cf2f2ba..5c090ce1c 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -271,9 +271,7 @@ public FunctionAnchor getAnchor() { private final Supplier> requiredArgsSupplier = Util.memoize( () -> { - return args().stream() - .filter(Argument::required) - .collect(java.util.stream.Collectors.toList()); + return args().stream().filter(Argument::required).collect(Collectors.toList()); }); public static String constructKeyFromTypes( @@ -656,7 +654,7 @@ private void checkNamespace(String name) { } public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { - var variant = aggregateFunctionsLookup.get().get(anchor); + AggregateFunctionVariant variant = aggregateFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -670,7 +668,7 @@ public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { } public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { - var variant = windowFunctionsLookup.get().get(anchor); + WindowFunctionVariant variant = windowFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -697,7 +695,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { } public static ExtensionCollection loadDefaults() { - var defaultFiles = + List defaultFiles = Arrays.asList( "boolean", "aggregate_generic", @@ -712,7 +710,7 @@ public static ExtensionCollection loadDefaults() { "string") .stream() .map(c -> String.format("/functions_%s.yaml", c)) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); return load(defaultFiles); } @@ -722,17 +720,17 @@ public static ExtensionCollection load(List resourcePaths) { throw new IllegalArgumentException("Require at least one resource path."); } - var extensions = + List extensions = resourcePaths.stream() .map( path -> { - try (var stream = ExtensionCollection.class.getResourceAsStream(path)) { + try (InputStream stream = ExtensionCollection.class.getResourceAsStream(path)) { return load(path, stream); } catch (IOException e) { throw new RuntimeException(e); } }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); ExtensionCollection complete = extensions.get(0); for (int i = 1; i < extensions.size(); i++) { complete = complete.merge(extensions.get(i)); @@ -742,7 +740,7 @@ public static ExtensionCollection load(List resourcePaths) { public static ExtensionCollection load(String namespace, String str) { try { - var doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); + ExtensionSignatures doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); return buildExtensionCollection(namespace, doc); } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -751,7 +749,8 @@ public static ExtensionCollection load(String namespace, String str) { public static ExtensionCollection load(String namespace, InputStream stream) { try { - var doc = objectMapper(namespace).readValue(stream, ExtensionSignatures.class); + ExtensionSignatures doc = + objectMapper(namespace).readValue(stream, ExtensionSignatures.class); return buildExtensionCollection(namespace, doc); } catch (RuntimeException ex) { throw ex; @@ -765,12 +764,12 @@ public static ExtensionCollection buildExtensionCollection( List scalarFunctionVariants = extensionSignatures.scalars().stream() .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); List aggregateFunctionVariants = extensionSignatures.aggregates().stream() .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); Stream windowFunctionVariants = extensionSignatures.windows().stream().flatMap(t -> t.resolve(namespace)); @@ -794,7 +793,7 @@ public static ExtensionCollection buildExtensionCollection( Stream.concat(windowFunctionVariants, windowAggFunctionVariants) .collect(Collectors.toList()); - var collection = + ImmutableSimpleExtension.ExtensionCollection collection = ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) .aggregateFunctions(aggregateFunctionVariants) diff --git a/core/src/main/java/io/substrait/hint/Hint.java b/core/src/main/java/io/substrait/hint/Hint.java index 238bf44e7..7497f0529 100644 --- a/core/src/main/java/io/substrait/hint/Hint.java +++ b/core/src/main/java/io/substrait/hint/Hint.java @@ -12,7 +12,8 @@ public abstract class Hint { public abstract List getOutputNames(); public RelCommon.Hint toProto() { - var builder = RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames()); + RelCommon.Hint.Builder builder = + RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames()); getAlias().ifPresent(builder::setAlias); return builder.build(); } diff --git a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java index 9accc5de2..8d7947096 100644 --- a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java @@ -33,7 +33,7 @@ public DdlRel.DdlObject toProto() { } public static DdlObject fromProto(DdlRel.DdlObject proto) { - for (var v : values()) { + for (DdlObject v : values()) { if (v.proto == proto) { return v; } @@ -61,7 +61,7 @@ public DdlRel.DdlOp toProto() { } public static DdlOp fromProto(DdlRel.DdlOp proto) { - for (var v : values()) { + for (DdlOp v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/relation/AbstractWriteRel.java b/core/src/main/java/io/substrait/relation/AbstractWriteRel.java index 23014724f..fe7f49bee 100644 --- a/core/src/main/java/io/substrait/relation/AbstractWriteRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractWriteRel.java @@ -33,7 +33,7 @@ public WriteRel.WriteOp toProto() { } public static WriteOp fromProto(WriteRel.WriteOp proto) { - for (var v : values()) { + for (WriteOp v : values()) { if (v.proto == proto) { return v; } @@ -60,7 +60,7 @@ public WriteRel.CreateMode toProto() { } public static CreateMode fromProto(WriteRel.CreateMode proto) { - for (var v : values()) { + for (CreateMode v : values()) { if (v.proto == proto) { return v; } @@ -85,7 +85,7 @@ public WriteRel.OutputMode toProto() { } public static OutputMode fromProto(WriteRel.OutputMode proto) { - for (var v : values()) { + for (OutputMode v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java index 7a9d0f569..08776ed5d 100644 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java @@ -3,9 +3,13 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.proto.AggregateFunction; +import io.substrait.proto.FunctionArgument; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; +import java.util.List; +import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -25,9 +29,10 @@ public AggregateFunctionProtoConverter(ExtensionCollector functionCollector) { } public AggregateFunction toProto(Aggregate.Measure measure) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = measure.getFunction().arguments(); - var aggFuncDef = measure.getFunction().declaration(); + FunctionArg.FuncArgVisitor + argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + List args = measure.getFunction().arguments(); + SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); return AggregateFunction.newBuilder() .setPhase(measure.getFunction().aggregationPhase().toProto()) @@ -39,7 +44,7 @@ public AggregateFunction toProto(Aggregate.Measure measure) { i -> args.get(i) .accept(aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .setFunctionReference( functionCollector.getFunctionReference(measure.getFunction().declaration())) .build(); diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java index 63e868f63..1d07ea6ca 100644 --- a/core/src/main/java/io/substrait/relation/Expand.java +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -53,8 +53,8 @@ public abstract static class SwitchingField implements ExpandField { public abstract List getDuplicates(); public Type getType() { - var nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); - var type = getDuplicates().get(0).getType(); + boolean nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); + Type type = getDuplicates().get(0).getType(); return nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type); } diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 8709f9d02..57132a940 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -212,9 +212,10 @@ public Optional visit( @Override public Optional visit(Expression.Switch expr, EmptyVisitationContext context) throws E { - var match = expr.match().accept(this, context); - var switchClauses = transformList(expr.switchClauses(), context, this::visitSwitchClause); - var defaultClause = expr.defaultClause().accept(this, context); + Optional match = expr.match().accept(this, context); + Optional> switchClauses = + transformList(expr.switchClauses(), context, this::visitSwitchClause); + Optional defaultClause = expr.defaultClause().accept(this, context); if (allEmpty(match, switchClauses, defaultClause)) { return Optional.empty(); @@ -242,8 +243,9 @@ protected Optional visitSwitchClause( @Override public Optional visit(Expression.IfThen ifThen, EmptyVisitationContext context) throws E { - var ifClauses = transformList(ifThen.ifClauses(), context, this::visitIfClause); - var elseClause = ifThen.elseClause().accept(this, context); + Optional> ifClauses = + transformList(ifThen.ifClauses(), context, this::visitIfClause); + Optional elseClause = ifThen.elseClause().accept(this, context); if (allEmpty(ifClauses, elseClause)) { return Optional.empty(); @@ -258,8 +260,8 @@ public Optional visit(Expression.IfThen ifThen, EmptyVisitationConte protected Optional visitIfClause( Expression.IfClause ifClause, EmptyVisitationContext context) throws E { - var condition = ifClause.condition().accept(this, context); - var then = ifClause.then().accept(this, context); + Optional condition = ifClause.condition().accept(this, context); + Optional then = ifClause.then().accept(this, context); if (allEmpty(condition, then)) { return Optional.empty(); @@ -287,9 +289,10 @@ public Optional visit( @Override public Optional visit( Expression.WindowFunctionInvocation wfi, EmptyVisitationContext context) throws E { - var arguments = visitFunctionArguments(wfi.arguments(), context); - var partitionBy = visitExprList(wfi.partitionBy(), context); - var sort = transformList(wfi.sort(), context, this::visitSortField); + Optional> arguments = visitFunctionArguments(wfi.arguments(), context); + Optional> partitionBy = visitExprList(wfi.partitionBy(), context); + Optional> sort = + transformList(wfi.sort(), context, this::visitSortField); if (allEmpty(arguments, partitionBy, sort)) { return Optional.empty(); @@ -313,8 +316,8 @@ public Optional visit(Expression.Cast cast, EmptyVisitationContext c @Override public Optional visit( Expression.SingleOrList singleOrList, EmptyVisitationContext context) throws E { - var condition = singleOrList.condition().accept(this, context); - var options = visitExprList(singleOrList.options(), context); + Optional condition = singleOrList.condition().accept(this, context); + Optional> options = visitExprList(singleOrList.options(), context); if (allEmpty(condition, options)) { return Optional.empty(); @@ -330,8 +333,8 @@ public Optional visit( @Override public Optional visit( Expression.MultiOrList multiOrList, EmptyVisitationContext context) throws E { - var conditions = visitExprList(multiOrList.conditions(), context); - var optionCombinations = + Optional> conditions = visitExprList(multiOrList.conditions(), context); + Optional> optionCombinations = transformList(multiOrList.optionCombinations(), context, this::visitMultiOrListRecord); if (allEmpty(conditions, optionCombinations)) { @@ -359,7 +362,8 @@ protected Optional visitMultiOrListRecord( @Override public Optional visit(FieldReference fieldReference, EmptyVisitationContext context) throws E { - var inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); + Optional inputExpression = + visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); @@ -389,8 +393,8 @@ public Optional visit( @Override public Optional visit( Expression.InPredicate inPredicate, EmptyVisitationContext context) throws E { - var haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); - var needles = visitExprList(inPredicate.needles(), context); + Optional haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); + Optional> needles = visitExprList(inPredicate.needles(), context); if (allEmpty(haystack, needles)) { return Optional.empty(); @@ -425,8 +429,8 @@ protected Optional> visitFunctionArguments( funcArgs, context, (arg, c) -> { - if (arg instanceof Expression expr) { - return expr.accept(this, c).flatMap(Optional::of); + if (arg instanceof Expression) { + return ((Expression) arg).accept(this, c).flatMap(Optional::of); } else { return Optional.empty(); } diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index 490bd7315..205e51572 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -51,7 +51,7 @@ public JoinRel.JoinType toProto() { } public static JoinType fromProto(JoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -63,33 +63,51 @@ public static JoinType fromProto(JoinRel.JoinType proto) { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER, RIGHT_SINGLE -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_SEMI, RIGHT_ANTI -> Stream - .of(); // these are right joins which ignore left side columns - case RIGHT_MARK -> Stream.of( - TypeCreator.REQUIRED - .BOOLEAN); // right mark join keeps all fields from right and adds a boolean mark - // field - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER, LEFT_SINGLE -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case SEMI, ANTI, LEFT_SEMI, LEFT_ANTI -> Stream - .of(); // these are left joins which ignore right side columns - case LEFT_MARK -> Stream.of( - TypeCreator.REQUIRED - .BOOLEAN); // left mark join keeps all fields from left and adds a boolean mark - // field - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case RIGHT: + case OUTER: + case RIGHT_SINGLE: + return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); + case RIGHT_SEMI: + case RIGHT_ANTI: + return Stream.of(); // these are right joins which ignore left side columns + case RIGHT_MARK: + return Stream.of( + TypeCreator.REQUIRED + .BOOLEAN); // right mark join keeps all fields from right and adds a boolean mark + // field + default: + return getLeft().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + case LEFT_SINGLE: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case SEMI: + case ANTI: + case LEFT_SEMI: + case LEFT_ANTI: + return Stream.of(); // these are left joins which ignore right side columns + case LEFT_MARK: + return Stream.of( + TypeCreator.REQUIRED + .BOOLEAN); // left mark join keeps all fields from left and adds a boolean mark + // field + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index c359ea420..f005a29be 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -67,77 +67,56 @@ public Plan.Root from(io.substrait.proto.RelRoot rel) { } public Rel from(io.substrait.proto.Rel rel) { - var relType = rel.getRelTypeCase(); + io.substrait.proto.Rel.RelTypeCase relType = rel.getRelTypeCase(); switch (relType) { - case READ -> { + case READ: return newRead(rel.getRead()); - } - case FILTER -> { + case FILTER: return newFilter(rel.getFilter()); - } - case FETCH -> { + case FETCH: return newFetch(rel.getFetch()); - } - case AGGREGATE -> { + case AGGREGATE: return newAggregate(rel.getAggregate()); - } - case SORT -> { + case SORT: return newSort(rel.getSort()); - } - case JOIN -> { + case JOIN: return newJoin(rel.getJoin()); - } - case SET -> { + case SET: return newSet(rel.getSet()); - } - case PROJECT -> { + case PROJECT: return newProject(rel.getProject()); - } - case EXPAND -> { + case EXPAND: return newExpand(rel.getExpand()); - } - case CROSS -> { + case CROSS: return newCross(rel.getCross()); - } - case EXTENSION_LEAF -> { + case EXTENSION_LEAF: return newExtensionLeaf(rel.getExtensionLeaf()); - } - case EXTENSION_SINGLE -> { + case EXTENSION_SINGLE: return newExtensionSingle(rel.getExtensionSingle()); - } - case EXTENSION_MULTI -> { + case EXTENSION_MULTI: return newExtensionMulti(rel.getExtensionMulti()); - } - case HASH_JOIN -> { + case HASH_JOIN: return newHashJoin(rel.getHashJoin()); - } - case MERGE_JOIN -> { + case MERGE_JOIN: return newMergeJoin(rel.getMergeJoin()); - } - case NESTED_LOOP_JOIN -> { + case NESTED_LOOP_JOIN: return newNestedLoopJoin(rel.getNestedLoopJoin()); - } - case WINDOW -> { + case WINDOW: return newConsistentPartitionWindow(rel.getWindow()); - } - case WRITE -> { + case WRITE: return newWrite(rel.getWrite()); - } - case DDL -> { + case DDL: return newDdl(rel.getDdl()); - } - case UPDATE -> { + case UPDATE: return newUpdate(rel.getUpdate()); - } - default -> { + default: throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); - } } } protected Rel newRead(ReadRel rel) { if (rel.hasVirtualTable()) { - var virtualTable = rel.getVirtualTable(); + ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); if (virtualTable.getValuesCount() == 0) { return newEmptyScan(rel); } else { @@ -155,21 +134,20 @@ protected Rel newRead(ReadRel rel) { } protected Rel newWrite(WriteRel rel) { - var relType = rel.getWriteTypeCase(); + WriteRel.WriteTypeCase relType = rel.getWriteTypeCase(); switch (relType) { - case NAMED_TABLE -> { + case NAMED_TABLE: return newNamedWrite(rel); - } - case EXTENSION_TABLE -> { + case EXTENSION_TABLE: return newExtensionWrite(rel); - } - default -> throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); + default: + throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); } } protected NamedWrite newNamedWrite(WriteRel rel) { - var input = from(rel.getInput()); - var builder = + Rel input = from(rel.getInput()); + ImmutableNamedWrite.Builder builder = NamedWrite.builder() .input(input) .names(rel.getNamedTable().getNamesList()) @@ -186,9 +164,10 @@ protected NamedWrite newNamedWrite(WriteRel rel) { } protected Rel newExtensionWrite(WriteRel rel) { - var input = from(rel.getInput()); - var detail = detailFromWriteExtensionObject(rel.getExtensionTable().getDetail()); - var builder = + Rel input = from(rel.getInput()); + Extension.WriteExtensionObject detail = + detailFromWriteExtensionObject(rel.getExtensionTable().getDetail()); + ImmutableExtensionWrite.Builder builder = ExtensionWrite.builder() .input(input) .detail(detail) @@ -205,20 +184,19 @@ protected Rel newExtensionWrite(WriteRel rel) { } protected Rel newDdl(DdlRel rel) { - var relType = rel.getWriteTypeCase(); + DdlRel.WriteTypeCase relType = rel.getWriteTypeCase(); switch (relType) { - case NAMED_OBJECT -> { + case NAMED_OBJECT: return newNamedDdl(rel); - } - case EXTENSION_OBJECT -> { + case EXTENSION_OBJECT: return newExtensionDdl(rel); - } - default -> throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); + default: + throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); } } protected NamedDdl newNamedDdl(DdlRel rel) { - var tableSchema = newNamedStruct(rel.getTableSchema()); + NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); return NamedDdl.builder() .names(rel.getNamedObject().getNamesList()) .tableSchema(tableSchema) @@ -233,8 +211,9 @@ protected NamedDdl newNamedDdl(DdlRel rel) { } protected ExtensionDdl newExtensionDdl(DdlRel rel) { - var detail = detailFromDdlExtensionObject(rel.getExtensionObject().getDetail()); - var tableSchema = newNamedStruct(rel.getTableSchema()); + Extension.DdlExtensionObject detail = + detailFromDdlExtensionObject(rel.getExtensionObject().getDetail()); + NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); return ExtensionDdl.builder() .detail(detail) .tableSchema(newNamedStruct(rel.getTableSchema())) @@ -254,7 +233,8 @@ protected Optional optionalViewDefinition(DdlRel rel) { protected Expression.StructLiteral tableDefaults( io.substrait.proto.Expression.Literal.Struct struct, NamedStruct tableSchema) { - var converter = new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); return Expression.StructLiteral.builder() .fields( struct.getFieldsList().stream() @@ -264,29 +244,29 @@ protected Expression.StructLiteral tableDefaults( } protected Rel newUpdate(UpdateRel rel) { - var relType = rel.getUpdateTypeCase(); + UpdateRel.UpdateTypeCase relType = rel.getUpdateTypeCase(); switch (relType) { - case NAMED_TABLE -> { + case NAMED_TABLE: return newNamedUpdate(rel); - } - default -> throw new UnsupportedOperationException( - "Unsupported UpdateTypeCase of " + relType); + default: + throw new UnsupportedOperationException("Unsupported UpdateTypeCase of " + relType); } } protected Rel newNamedUpdate(UpdateRel rel) { - var tableSchema = newNamedStruct(rel.getTableSchema()); - var converter = new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); + NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); List transformations = new ArrayList<>(rel.getTransformationsCount()); - for (var transformation : rel.getTransformationsList()) { + for (UpdateRel.TransformExpression transformation : rel.getTransformationsList()) { transformations.add( NamedUpdate.TransformExpression.builder() .transformation(converter.from(transformation.getTransformation())) .columnTarget(transformation.getColumnTarget()) .build()); } - var builder = + ImmutableNamedUpdate.Builder builder = NamedUpdate.builder() .names(rel.getNamedTable().getNamesList()) .tableSchema(tableSchema) @@ -299,8 +279,8 @@ protected Rel newNamedUpdate(UpdateRel rel) { } protected Filter newFilter(FilterRel rel) { - var input = from(rel.getInput()); - var builder = + Rel input = from(rel.getInput()); + ImmutableFilter.Builder builder = Filter.builder() .input(input) .condition( @@ -321,7 +301,7 @@ protected NamedStruct newNamedStruct(ReadRel rel) { } protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) { - var struct = namedStruct.getStruct(); + io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); return NamedStruct.builder() .names(namedStruct.getNamesList()) .struct( @@ -336,8 +316,8 @@ protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) } protected EmptyScan newEmptyScan(ReadRel rel) { - var namedStruct = newNamedStruct(rel); - var builder = + NamedStruct namedStruct = newNamedStruct(rel); + ImmutableEmptyScan.Builder builder = EmptyScan.builder() .initialSchema(namedStruct) .bestEffortFilter( @@ -367,7 +347,7 @@ protected EmptyScan newEmptyScan(ReadRel rel) { protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { Extension.LeafRelDetail detail = detailFromExtensionLeafRel(rel.getDetail()); - var builder = + ImmutableExtensionLeaf.Builder builder = ExtensionLeaf.from(detail) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -378,7 +358,7 @@ protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) { Extension.SingleRelDetail detail = detailFromExtensionSingleRel(rel.getDetail()); Rel input = from(rel.getInput()); - var builder = + ImmutableExtensionSingle.Builder builder = ExtensionSingle.from(detail, input) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -389,7 +369,7 @@ protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) { protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { Extension.MultiRelDetail detail = detailFromExtensionMultiRel(rel.getDetail()); List inputs = rel.getInputsList().stream().map(this::from).collect(Collectors.toList()); - var builder = + ImmutableExtensionMulti.Builder builder = ExtensionMulti.from(detail, inputs) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -401,8 +381,8 @@ protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { } protected NamedScan newNamedScan(ReadRel rel) { - var namedStruct = newNamedStruct(rel); - var builder = + NamedStruct namedStruct = newNamedStruct(rel); + ImmutableNamedScan.Builder builder = NamedScan.builder() .initialSchema(namedStruct) .names(rel.getNamedTable().getNamesList()) @@ -434,7 +414,7 @@ protected NamedScan newNamedScan(ReadRel rel) { protected ExtensionTable newExtensionTable(ReadRel rel) { Extension.ExtensionTableDetail detail = detailFromExtensionTable(rel.getExtensionTable().getDetail()); - var builder = ExtensionTable.from(detail); + ImmutableExtensionTable.Builder builder = ExtensionTable.from(detail); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -447,9 +427,9 @@ protected ExtensionTable newExtensionTable(ReadRel rel) { } protected LocalFiles newLocalFiles(ReadRel rel) { - var namedStruct = newNamedStruct(rel); + NamedStruct namedStruct = newNamedStruct(rel); - var builder = + ImmutableLocalFiles.Builder builder = LocalFiles.builder() .initialSchema(namedStruct) .addAllItems( @@ -482,7 +462,7 @@ protected LocalFiles newLocalFiles(ReadRel rel) { } protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { - var builder = + io.substrait.relation.files.ImmutableFileOrFiles.Builder builder = FileOrFiles.builder() .partitionIndex(file.getPartitionIndex()) .start(file.getStart()) @@ -496,13 +476,14 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { } else if (file.hasDwrf()) { builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); } else if (file.hasText()) { - var ffBuilder = - FileFormat.DelimiterSeparatedTextReadOptions.builder() - .fieldDelimiter(file.getText().getFieldDelimiter()) - .maxLineSize(file.getText().getMaxLineSize()) - .quote(file.getText().getQuote()) - .headerLinesToSkip(file.getText().getHeaderLinesToSkip()) - .escape(file.getText().getEscape()); + io.substrait.relation.files.ImmutableFileFormat.DelimiterSeparatedTextReadOptions.Builder + ffBuilder = + FileFormat.DelimiterSeparatedTextReadOptions.builder() + .fieldDelimiter(file.getText().getFieldDelimiter()) + .maxLineSize(file.getText().getMaxLineSize()) + .quote(file.getText().getQuote()) + .headerLinesToSkip(file.getText().getHeaderLinesToSkip()) + .escape(file.getText().getEscape()); if (file.getText().hasValueTreatedAsNull()) { ffBuilder.valueTreatedAsNull(file.getText().getValueTreatedAsNull()); } @@ -523,12 +504,12 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { } protected VirtualTableScan newVirtualTable(ReadRel rel) { - var virtualTable = rel.getVirtualTable(); - var virtualTableSchema = newNamedStruct(rel); - var converter = + ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); + NamedStruct virtualTableSchema = newNamedStruct(rel); + ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this); List structLiterals = new ArrayList<>(virtualTable.getValuesCount()); - for (var struct : virtualTable.getValuesList()) { + for (io.substrait.proto.Expression.Literal.Struct struct : virtualTable.getValuesList()) { structLiterals.add( Expression.StructLiteral.builder() .fields( @@ -538,7 +519,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { .build()); } - var builder = + ImmutableVirtualTableScan.Builder builder = VirtualTableScan.builder() .bestEffortFilter( Optional.ofNullable( @@ -558,8 +539,8 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { } protected Fetch newFetch(FetchRel rel) { - var input = from(rel.getInput()); - var builder = Fetch.builder().input(input).offset(rel.getOffset()); + Rel input = from(rel.getInput()); + ImmutableFetch.Builder builder = Fetch.builder().input(input).offset(rel.getOffset()); if (rel.getCount() != -1) { // -1 is used as a sentinel value to signal LIMIT ALL // count only needs to be set when it is not -1 @@ -577,9 +558,10 @@ protected Fetch newFetch(FetchRel rel) { } protected Project newProject(ProjectRel rel) { - var input = from(rel.getInput()); - var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var builder = + Rel input = from(rel.getInput()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + ImmutableProject.Builder builder = Project.builder() .input(input) .expressions( @@ -598,28 +580,33 @@ protected Project newProject(ProjectRel rel) { } protected Expand newExpand(ExpandRel rel) { - var input = from(rel.getInput()); - var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var builder = + Rel input = from(rel.getInput()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + ImmutableExpand.Builder builder = Expand.builder() .input(input) .fields( rel.getFieldsList().stream() .map( - expandField -> - switch (expandField.getFieldTypeCase()) { - case CONSISTENT_FIELD -> Expand.ConsistentField.builder() + expandField -> { + switch (expandField.getFieldTypeCase()) { + case CONSISTENT_FIELD: + return Expand.ConsistentField.builder() .expression(converter.from(expandField.getConsistentField())) .build(); - case SWITCHING_FIELD -> Expand.SwitchingField.builder() + case SWITCHING_FIELD: + return Expand.SwitchingField.builder() .duplicates( expandField.getSwitchingField().getDuplicatesList().stream() .map(converter::from) .collect(java.util.stream.Collectors.toList())) .build(); - case FIELDTYPE_NOT_SET -> throw new UnsupportedOperationException( - "Expand fields not set"); - }) + case FIELDTYPE_NOT_SET: + default: + throw new UnsupportedOperationException("Expand fields not set"); + } + }) .collect(java.util.stream.Collectors.toList())); builder @@ -630,14 +617,14 @@ protected Expand newExpand(ExpandRel rel) { } protected Aggregate newAggregate(AggregateRel rel) { - var input = from(rel.getInput()); - var protoExprConverter = + Rel input = from(rel.getInput()); + ProtoExpressionConverter protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var protoAggrFuncConverter = + ProtoAggregateFunctionConverter protoAggrFuncConverter = new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); List groupings = new ArrayList<>(rel.getGroupingsCount()); - for (var grouping : rel.getGroupingsList()) { + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { groupings.add( Aggregate.Grouping.builder() .expressions( @@ -647,11 +634,12 @@ protected Aggregate newAggregate(AggregateRel rel) { .build()); } List measures = new ArrayList<>(rel.getMeasuresCount()); - var pF = new FunctionArg.ProtoFrom(protoExprConverter, protoTypeConverter); - for (var measure : rel.getMeasuresList()) { - var func = measure.getMeasure(); - var funcDecl = lookup.getAggregateFunction(func.getFunctionReference(), extensions); - var args = + FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(protoExprConverter, protoTypeConverter); + for (AggregateRel.Measure measure : rel.getMeasuresList()) { + io.substrait.proto.AggregateFunction func = measure.getMeasure(); + SimpleExtension.AggregateFunctionVariant funcDecl = + lookup.getAggregateFunction(func.getFunctionReference(), extensions); + List args = IntStream.range(0, measure.getMeasure().getArgumentsCount()) .mapToObj(i -> pF.convert(funcDecl, i, measure.getMeasure().getArguments(i))) .collect(java.util.stream.Collectors.toList()); @@ -663,7 +651,8 @@ protected Aggregate newAggregate(AggregateRel rel) { measure.hasFilter() ? protoExprConverter.from(measure.getFilter()) : null)) .build()); } - var builder = Aggregate.builder().input(input).groupings(groupings).measures(measures); + ImmutableAggregate.Builder builder = + Aggregate.builder().input(input).groupings(groupings).measures(measures); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -676,9 +665,10 @@ protected Aggregate newAggregate(AggregateRel rel) { } protected Sort newSort(SortRel rel) { - var input = from(rel.getInput()); - var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var builder = + Rel input = from(rel.getInput()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + ImmutableSort.Builder builder = Sort.builder() .input(input) .sortFields( @@ -707,8 +697,9 @@ protected Join newJoin(JoinRel rel) { Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + ImmutableJoin.Builder builder = Join.builder() .left(left) .right(right) @@ -731,7 +722,7 @@ protected Join newJoin(JoinRel rel) { protected Rel newCross(CrossRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - var builder = Cross.builder().left(left).right(right); + ImmutableCross.Builder builder = Cross.builder().left(left).right(right); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -747,7 +738,8 @@ protected Set newSet(SetRel rel) { rel.getInputsList().stream() .map(inputRel -> from(inputRel)) .collect(java.util.stream.Collectors.toList()); - var builder = Set.builder().inputs(inputs).setOp(Set.SetOp.fromProto(rel.getOp())); + ImmutableSet.Builder builder = + Set.builder().inputs(inputs).setOp(Set.SetOp.fromProto(rel.getOp())); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -762,16 +754,19 @@ protected Set newSet(SetRel rel) { protected Rel newHashJoin(HashJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - var leftKeys = rel.getLeftKeysList(); - var rightKeys = rel.getRightKeysList(); + List leftKeys = rel.getLeftKeysList(); + List rightKeys = rel.getRightKeysList(); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); - var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); - var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter leftConverter = + new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + ProtoExpressionConverter rightConverter = + new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + ProtoExpressionConverter unionConverter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + io.substrait.relation.physical.ImmutableHashJoin.Builder builder = HashJoin.builder() .left(left) .right(right) @@ -794,16 +789,19 @@ protected Rel newHashJoin(HashJoinRel rel) { protected Rel newMergeJoin(MergeJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - var leftKeys = rel.getLeftKeysList(); - var rightKeys = rel.getRightKeysList(); + List leftKeys = rel.getLeftKeysList(); + List rightKeys = rel.getRightKeysList(); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); - var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); - var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter leftConverter = + new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + ProtoExpressionConverter rightConverter = + new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + ProtoExpressionConverter unionConverter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + io.substrait.relation.physical.ImmutableMergeJoin.Builder builder = MergeJoin.builder() .left(left) .right(right) @@ -830,8 +828,9 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + io.substrait.relation.physical.ImmutableNestedLoopJoin.Builder builder = NestedLoopJoin.builder() .left(left) .right(right) @@ -855,24 +854,24 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { protected ConsistentPartitionWindow newConsistentPartitionWindow( ConsistentPartitionWindowRel rel) { - var input = from(rel.getInput()); - var protoExpressionConverter = + Rel input = from(rel.getInput()); + ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var partitionExprs = + List partitionExprs = rel.getPartitionExpressionsList().stream() .map(protoExpressionConverter::from) .collect(Collectors.toList()); - var sortFields = + List sortFields = rel.getSortsList().stream() .map(protoExpressionConverter::fromSortField) .collect(Collectors.toList()); - var windowRelFunctions = + List windowRelFunctions = rel.getWindowFunctionsList().stream() .map(protoExpressionConverter::fromWindowRelFunction) .collect(Collectors.toList()); - var builder = + ImmutableConsistentPartitionWindow.Builder builder = ConsistentPartitionWindow.builder() .input(input) .partitionExpressions(partitionExprs) @@ -896,8 +895,9 @@ protected static Optional optionalRelmap(io.substrait.proto.RelCommon protected static Optional optionalHint(io.substrait.proto.RelCommon relCommon) { if (!relCommon.hasHint()) return Optional.empty(); - var hint = relCommon.getHint(); - var builder = Hint.builder().addAllOutputNames(hint.getOutputNamesList()); + io.substrait.proto.RelCommon.Hint hint = relCommon.getHint(); + io.substrait.hint.ImmutableHint.Builder builder = + Hint.builder().addAllOutputNames(hint.getOutputNamesList()); if (!hint.getAlias().isEmpty()) { builder.alias(hint.getAlias()); } @@ -914,7 +914,7 @@ protected Optional optionalAdvancedExtension( protected AdvancedExtension advancedExtension( io.substrait.proto.AdvancedExtension advancedExtension) { - var builder = AdvancedExtension.builder(); + io.substrait.extension.ImmutableAdvancedExtension.Builder builder = AdvancedExtension.builder(); if (advancedExtension.hasEnhancement()) { builder.enhancement(enhancementFromAdvancedExtension(advancedExtension.getEnhancement())); } diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 87d99ba72..f5a3e8081 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -46,9 +46,11 @@ protected ExpressionCopyOnWriteVisitor getExpressionCopyOnWriteVisitor() { @Override public Optional visit(Aggregate aggregate, EmptyVisitationContext context) throws E { - var input = aggregate.getInput().accept(this, context); - var groupings = transformList(aggregate.getGroupings(), context, this::visitGrouping); - var measures = transformList(aggregate.getMeasures(), context, this::visitMeasure); + Optional input = aggregate.getInput().accept(this, context); + Optional> groupings = + transformList(aggregate.getGroupings(), context, this::visitGrouping); + Optional> measures = + transformList(aggregate.getMeasures(), context, this::visitMeasure); if (allEmpty(input, groupings, measures)) { return Optional.empty(); @@ -70,8 +72,10 @@ protected Optional visitGrouping( protected Optional visitMeasure( Aggregate.Measure measure, EmptyVisitationContext context) throws E { - var preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter(), context); - var afi = visitAggregateFunction(measure.getFunction(), context); + Optional preMeasureFilter = + visitOptionalExpression(measure.getPreMeasureFilter(), context); + Optional afi = + visitAggregateFunction(measure.getFunction(), context); if (allEmpty(preMeasureFilter, afi)) { return Optional.empty(); @@ -86,8 +90,9 @@ protected Optional visitMeasure( protected Optional visitAggregateFunction( AggregateFunctionInvocation afi, EmptyVisitationContext context) throws E { - var arguments = visitFunctionArguments(afi.arguments(), context); - var sort = transformList(afi.sort(), context, this::visitSortField); + Optional> arguments = visitFunctionArguments(afi.arguments(), context); + Optional> sort = + transformList(afi.sort(), context, this::visitSortField); if (allEmpty(arguments, sort)) { return Optional.empty(); @@ -124,8 +129,9 @@ public Optional visit(Fetch fetch, EmptyVisitationContext context) throws E @Override public Optional visit(Filter filter, EmptyVisitationContext context) throws E { - var input = filter.getInput().accept(this, context); - var condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); + Optional input = filter.getInput().accept(this, context); + Optional condition = + filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(input, condition)) { return Optional.empty(); @@ -140,10 +146,10 @@ public Optional visit(Filter filter, EmptyVisitationContext context) throws @Override public Optional visit(Join join, EmptyVisitationContext context) throws E { - var left = join.getLeft().accept(this, context); - var right = join.getRight().accept(this, context); - var condition = visitOptionalExpression(join.getCondition(), context); - var postFilter = visitOptionalExpression(join.getPostJoinFilter(), context); + Optional left = join.getLeft().accept(this, context); + Optional right = join.getRight().accept(this, context); + Optional condition = visitOptionalExpression(join.getCondition(), context); + Optional postFilter = visitOptionalExpression(join.getPostJoinFilter(), context); if (allEmpty(left, right, condition, postFilter)) { return Optional.empty(); @@ -166,7 +172,7 @@ public Optional visit(Set set, EmptyVisitationContext context) throws E { @Override public Optional visit(NamedScan namedScan, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(namedScan.getFilter(), context); + Optional filter = visitOptionalExpression(namedScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -177,7 +183,7 @@ public Optional visit(NamedScan namedScan, EmptyVisitationContext context) @Override public Optional visit(LocalFiles localFiles, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(localFiles.getFilter(), context); + Optional filter = visitOptionalExpression(localFiles.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -188,8 +194,8 @@ public Optional visit(LocalFiles localFiles, EmptyVisitationContext context @Override public Optional visit(Project project, EmptyVisitationContext context) throws E { - var input = project.getInput().accept(this, context); - var expressions = visitExprList(project.getExpressions(), context); + Optional input = project.getInput().accept(this, context); + Optional> expressions = visitExprList(project.getExpressions(), context); if (allEmpty(input, expressions)) { return Optional.empty(); @@ -234,8 +240,9 @@ public Optional visit(NamedUpdate update, EmptyVisitationContext context) t @Override public Optional visit(Sort sort, EmptyVisitationContext context) throws E { - var input = sort.getInput().accept(this, context); - var sortFields = transformList(sort.getSortFields(), context, this::visitSortField); + Optional input = sort.getInput().accept(this, context); + Optional> sortFields = + transformList(sort.getSortFields(), context, this::visitSortField); if (allEmpty(input, sortFields)) { return Optional.empty(); @@ -250,8 +257,8 @@ public Optional visit(Sort sort, EmptyVisitationContext context) throws E { @Override public Optional visit(Cross cross, EmptyVisitationContext context) throws E { - var left = cross.getLeft().accept(this, context); - var right = cross.getRight().accept(this, context); + Optional left = cross.getLeft().accept(this, context); + Optional right = cross.getRight().accept(this, context); if (allEmpty(left, right)) { return Optional.empty(); @@ -267,7 +274,7 @@ public Optional visit(Cross cross, EmptyVisitationContext context) throws E @Override public Optional visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(virtualTableScan.getFilter(), context); + Optional filter = visitOptionalExpression(virtualTableScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -303,7 +310,7 @@ public Optional visit(ExtensionMulti extensionMulti, EmptyVisitationContext @Override public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(extensionTable.getFilter(), context); + Optional filter = visitOptionalExpression(extensionTable.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -317,11 +324,14 @@ public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext @Override public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) throws E { - var left = hashJoin.getLeft().accept(this, context); - var right = hashJoin.getRight().accept(this, context); - var leftKeys = transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); - var rightKeys = transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); - var postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter(), context); + Optional left = hashJoin.getLeft().accept(this, context); + Optional right = hashJoin.getRight().accept(this, context); + Optional> leftKeys = + transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); + Optional> rightKeys = + transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); + Optional postFilter = + visitOptionalExpression(hashJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { return Optional.empty(); @@ -339,11 +349,14 @@ public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) th @Override public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E { - var left = mergeJoin.getLeft().accept(this, context); - var right = mergeJoin.getRight().accept(this, context); - var leftKeys = transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); - var rightKeys = transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); - var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); + Optional left = mergeJoin.getLeft().accept(this, context); + Optional right = mergeJoin.getRight().accept(this, context); + Optional> leftKeys = + transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); + Optional> rightKeys = + transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); + Optional postFilter = + visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { return Optional.empty(); @@ -362,9 +375,9 @@ public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) @Override public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) throws E { - var left = nestedLoopJoin.getLeft().accept(this, context); - var right = nestedLoopJoin.getRight().accept(this, context); - var condition = + Optional left = nestedLoopJoin.getLeft().accept(this, context); + Optional right = nestedLoopJoin.getRight().accept(this, context); + Optional condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(left, right, condition)) { @@ -383,15 +396,16 @@ public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext public Optional visit( ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) throws E { - var windowFunctions = + Optional> windowFunctions = transformList( consistentPartitionWindow.getWindowFunctions(), context, this::visitWindowRelFunction); - var partitionExpressions = + Optional> partitionExpressions = transformList( consistentPartitionWindow.getPartitionExpressions(), context, (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c)); - var sorts = transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField); + Optional> sorts = + transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField); if (allEmpty(windowFunctions, partitionExpressions, sorts)) { return Optional.empty(); @@ -411,7 +425,8 @@ protected Optional visitW ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation, EmptyVisitationContext context) throws E { - var functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments(), context); + Optional> functionArgs = + visitFunctionArguments(windowRelFunctionInvocation.arguments(), context); if (allEmpty(functionArgs)) { return Optional.empty(); @@ -433,7 +448,8 @@ protected Optional> visitExprList( public Optional visitFieldReference( FieldReference fieldReference, EmptyVisitationContext context) throws E { - var inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); + Optional inputExpression = + visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); } @@ -447,12 +463,13 @@ protected Optional> visitFunctionArguments( funcArgs, context, (arg, c) -> { - if (arg instanceof Expression expr) { - return expr.accept(getExpressionCopyOnWriteVisitor(), c) + if (arg instanceof Expression) { + return ((Expression) arg) + .accept(getExpressionCopyOnWriteVisitor(), c) .flatMap(Optional::of); - } else { - return Optional.empty(); } + + return Optional.empty(); }); } diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index f75d2ab9d..54fdaa433 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan; import io.substrait.proto.AggregateFunction; import io.substrait.proto.AggregateRel; @@ -111,7 +112,7 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { - var builder = + AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) @@ -125,11 +126,13 @@ public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws Run } private AggregateRel.Measure toProto(Aggregate.Measure measure) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = measure.getFunction().arguments(); - var aggFuncDef = measure.getFunction().declaration(); + FunctionArg.FuncArgVisitor< + io.substrait.proto.FunctionArgument, EmptyVisitationContext, RuntimeException> + argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + List args = measure.getFunction().arguments(); + SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); - var func = + AggregateFunction.Builder func = AggregateFunction.newBuilder() .setPhase(measure.getFunction().aggregationPhase().toProto()) .setInvocation(measure.getFunction().invocation().toProto()) @@ -149,7 +152,7 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .map(ExpressionProtoConverter::from) .collect(Collectors.toList())); - var builder = AggregateRel.Measure.newBuilder().setMeasure(func); + AggregateRel.Measure.Builder builder = AggregateRel.Measure.newBuilder().setMeasure(func); measure.getPreMeasureFilter().ifPresent(f -> builder.setFilter(toProto(f))); return builder.build(); @@ -175,7 +178,7 @@ public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws Run @Override public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { - var builder = + FetchRel.Builder builder = FetchRel.newBuilder() .setCommon(common(fetch)) .setInput(toProto(fetch.getInput())) @@ -189,7 +192,7 @@ public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeExce @Override public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { - var builder = + FilterRel.Builder builder = FilterRel.newBuilder() .setCommon(common(filter)) .setInput(toProto(filter.getInput())) @@ -201,7 +204,7 @@ public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeEx @Override public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeException { - var builder = + JoinRel.Builder builder = JoinRel.newBuilder() .setCommon(common(join)) .setLeft(toProto(join.getLeft())) @@ -218,7 +221,8 @@ public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeExcept @Override public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeException { - var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); + SetRel.Builder builder = + SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); set.getInputs() .forEach( inputRel -> { @@ -231,7 +235,7 @@ public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeExceptio @Override public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws RuntimeException { - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(namedScan)) .setNamedTable(ReadRel.NamedTable.newBuilder().addAllNames(namedScan.getNames())) @@ -246,7 +250,7 @@ public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws Run @Override public Rel visit(LocalFiles localFiles, EmptyVisitationContext context) throws RuntimeException { - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(localFiles)) .setLocalFiles( @@ -269,7 +273,7 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) throws RuntimeException { ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder().setDetail(extensionTable.getDetail().toProto(this)); - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(extensionTable)) .setBaseSchema(extensionTable.getInitialSchema().toProto(typeProtoConverter)) @@ -281,7 +285,7 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) @Override public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { - var builder = + HashJoinRel.Builder builder = HashJoinRel.newBuilder() .setCommon(common(hashJoin)) .setLeft(toProto(hashJoin.getLeft())) @@ -306,7 +310,7 @@ public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws Runti @Override public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws RuntimeException { - var builder = + MergeJoinRel.Builder builder = MergeJoinRel.newBuilder() .setCommon(common(mergeJoin)) .setLeft(toProto(mergeJoin.getLeft())) @@ -332,7 +336,7 @@ public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws Run @Override public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) throws RuntimeException { - var builder = + NestedLoopJoinRel.Builder builder = NestedLoopJoinRel.newBuilder() .setCommon(common(nestedLoopJoin)) .setLeft(toProto(nestedLoopJoin.getLeft())) @@ -348,7 +352,7 @@ public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) public Rel visit( ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) throws RuntimeException { - var builder = + ConsistentPartitionWindowRel.Builder builder = ConsistentPartitionWindowRel.newBuilder() .setCommon(common(consistentPartitionWindow)) .setInput(toProto(consistentPartitionWindow.getInput())) @@ -367,7 +371,7 @@ public Rel visit( @Override public Rel visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { - var builder = + WriteRel.Builder builder = WriteRel.newBuilder() .setCommon(common(write)) .setInput(toProto(write.getInput())) @@ -382,7 +386,7 @@ public Rel visit(NamedWrite write, EmptyVisitationContext context) throws Runtim @Override public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws RuntimeException { - var builder = + WriteRel.Builder builder = WriteRel.newBuilder() .setCommon(common(write)) .setInput(toProto(write.getInput())) @@ -398,7 +402,7 @@ public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws Ru @Override public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { - var builder = + DdlRel.Builder builder = DdlRel.newBuilder() .setCommon(common(ddl)) .setTableSchema(ddl.getTableSchema().toProto(typeProtoConverter)) @@ -415,7 +419,7 @@ public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeExc @Override public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { - var builder = + DdlRel.Builder builder = DdlRel.newBuilder() .setCommon(common(ddl)) .setTableSchema(ddl.getTableSchema().toProto(typeProtoConverter)) @@ -433,7 +437,7 @@ public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Runtim @Override public Rel visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { - var builder = + UpdateRel.Builder builder = UpdateRel.newBuilder() .setNamedTable(NamedTable.newBuilder().addAllNames(update.getNames())) .setTableSchema(update.getTableSchema().toProto(typeProtoConverter)) @@ -460,11 +464,13 @@ private List toProtoWindowRelFun return windowRelFunctionInvocations.stream() .map( f -> { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = f.arguments(); - var aggFuncDef = f.declaration(); + FunctionArg.FuncArgVisitor< + io.substrait.proto.FunctionArgument, EmptyVisitationContext, RuntimeException> + argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + List args = f.arguments(); + SimpleExtension.WindowFunctionVariant aggFuncDef = f.declaration(); - var arguments = + List arguments = IntStream.range(0, args.size()) .mapToObj( i -> @@ -472,7 +478,7 @@ private List toProtoWindowRelFun .accept( aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) .collect(Collectors.toList()); - var options = + List options = f.options().stream() .map(ExpressionProtoConverter::from) .collect(Collectors.toList()); @@ -494,7 +500,7 @@ private List toProtoWindowRelFun @Override public Rel visit(Project project, EmptyVisitationContext context) throws RuntimeException { - var builder = + ProjectRel.Builder builder = ProjectRel.newBuilder() .setCommon(common(project)) .setInput(toProto(project.getInput())) @@ -506,20 +512,22 @@ public Rel visit(Project project, EmptyVisitationContext context) throws Runtime @Override public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { - var builder = + ExpandRel.Builder builder = ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput())); expand .getFields() .forEach( expandField -> { - if (expandField instanceof Expand.ConsistentField cf) { + if (expandField instanceof Expand.ConsistentField) { + Expand.ConsistentField cf = (Expand.ConsistentField) expandField; builder.addFields( ExpandRel.ExpandField.newBuilder() .setConsistentField(toProto(cf.getExpression())) .build()); - } else if (expandField instanceof Expand.SwitchingField sf) { + } else if (expandField instanceof Expand.SwitchingField) { + Expand.SwitchingField sf = (Expand.SwitchingField) expandField; builder.addFields( ExpandRel.ExpandField.newBuilder() .setSwitchingField( @@ -536,7 +544,7 @@ public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeEx @Override public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { - var builder = + SortRel.Builder builder = SortRel.newBuilder() .setCommon(common(sort)) .setInput(toProto(sort.getInput())) @@ -548,7 +556,7 @@ public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeExcept @Override public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { - var builder = + CrossRel.Builder builder = CrossRel.newBuilder() .setCommon(common(cross)) .setLeft(toProto(cross.getLeft())) @@ -561,7 +569,7 @@ public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeExce @Override public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) throws RuntimeException { - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(virtualTableScan)) .setVirtualTable( @@ -584,7 +592,7 @@ public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext conte @Override public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) throws RuntimeException { - var builder = + ExtensionLeafRel.Builder builder = ExtensionLeafRel.newBuilder() .setCommon(common(extensionLeaf)) .setDetail(extensionLeaf.getDetail().toProto(this)); @@ -594,7 +602,7 @@ public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) @Override public Rel visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) throws RuntimeException { - var builder = + ExtensionSingleRel.Builder builder = ExtensionSingleRel.newBuilder() .setCommon(common(extensionSingle)) .setInput(toProto(extensionSingle.getInput())) @@ -607,7 +615,7 @@ public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) throws RuntimeException { List inputs = extensionMulti.getInputs().stream().map(this::toProto).collect(Collectors.toList()); - var builder = + ExtensionMultiRel.Builder builder = ExtensionMultiRel.newBuilder() .setCommon(common(extensionMulti)) .addAllInputs(inputs) @@ -616,11 +624,11 @@ public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) } private RelCommon common(io.substrait.relation.Rel rel) { - var builder = RelCommon.newBuilder(); + RelCommon.Builder builder = RelCommon.newBuilder(); rel.getCommonExtension() .ifPresent(extension -> builder.setAdvancedExtension(extension.toProto(this))); - var remap = rel.getRemap().orElse(null); + io.substrait.relation.Rel.Remap remap = rel.getRemap().orElse(null); if (remap != null) { builder.setEmit(RelCommon.Emit.newBuilder().addAllOutputMapping(remap.indices())); } else { diff --git a/core/src/main/java/io/substrait/relation/Set.java b/core/src/main/java/io/substrait/relation/Set.java index 697cda1be..dcf29ca6b 100644 --- a/core/src/main/java/io/substrait/relation/Set.java +++ b/core/src/main/java/io/substrait/relation/Set.java @@ -35,7 +35,7 @@ public SetRel.SetOp toProto() { } public static SetOp fromProto(SetRel.SetOp proto) { - for (var v : values()) { + for (SetOp v : values()) { if (v.proto == proto) { return v; } @@ -66,14 +66,24 @@ protected Type.Struct deriveRecordType() { } // As defined in https://substrait.io/relations/logical_relations/#set-operation-types - return switch (getSetOp()) { - case UNKNOWN -> first; // alternative would be to throw an exception - case MINUS_PRIMARY, MINUS_PRIMARY_ALL, MINUS_MULTISET -> first; - case INTERSECTION_PRIMARY -> coalesceNullabilityIntersectionPrimary(first, rest); - case INTERSECTION_MULTISET, INTERSECTION_MULTISET_ALL -> coalesceNullabilityIntersection( - first, rest); - case UNION_DISTINCT, UNION_ALL -> coalesceNullabilityUnion(first, rest); - }; + switch (getSetOp()) { + case UNKNOWN: + return first; // alternative would be to throw an exception + case MINUS_PRIMARY: + case MINUS_PRIMARY_ALL: + case MINUS_MULTISET: + return first; + case INTERSECTION_PRIMARY: + return coalesceNullabilityIntersectionPrimary(first, rest); + case INTERSECTION_MULTISET: + case INTERSECTION_MULTISET_ALL: + return coalesceNullabilityIntersection(first, rest); + case UNION_DISTINCT: + case UNION_ALL: + return coalesceNullabilityUnion(first, rest); + default: + throw new UnsupportedOperationException("Unexpected set operation: " + getSetOp()); + } } /** If field is nullable in any of the inputs, it's nullable in the output */ diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 6eb7a361d..0b9f61e28 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -24,11 +24,11 @@ public abstract class VirtualTableScan extends AbstractReadRel { */ @Value.Check protected void check() { - var names = getInitialSchema().names(); + List names = getInitialSchema().names(); assert names.size() == NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct()); - var rows = getRows(); + List rows = getRows(); assert rows.size() > 0 && names.stream().noneMatch(s -> s == null) diff --git a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java index 4d0bb421f..4b227025f 100644 --- a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java +++ b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java @@ -36,29 +36,33 @@ default ReadRel.LocalFiles.FileOrFiles toProto() { getFileFormat() .ifPresent( fileFormat -> { - if (fileFormat instanceof FileFormat.ParquetReadOptions options) { + if (fileFormat instanceof FileFormat.ParquetReadOptions) { builder.setParquet( ReadRel.LocalFiles.FileOrFiles.ParquetReadOptions.newBuilder().build()); - } else if (fileFormat instanceof FileFormat.ArrowReadOptions options) { + } else if (fileFormat instanceof FileFormat.ArrowReadOptions) { builder.setArrow( ReadRel.LocalFiles.FileOrFiles.ArrowReadOptions.newBuilder().build()); - } else if (fileFormat instanceof FileFormat.OrcReadOptions options) { + } else if (fileFormat instanceof FileFormat.OrcReadOptions) { builder.setOrc(ReadRel.LocalFiles.FileOrFiles.OrcReadOptions.newBuilder().build()); - } else if (fileFormat instanceof FileFormat.DwrfReadOptions options) { + } else if (fileFormat instanceof FileFormat.DwrfReadOptions) { builder.setDwrf( ReadRel.LocalFiles.FileOrFiles.DwrfReadOptions.newBuilder().build()); - } else if (fileFormat - instanceof FileFormat.DelimiterSeparatedTextReadOptions options) { - var optionsBuilder = - ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.newBuilder() - .setFieldDelimiter(options.getFieldDelimiter()) - .setMaxLineSize(options.getMaxLineSize()) - .setQuote(options.getQuote()) - .setHeaderLinesToSkip(options.getHeaderLinesToSkip()) - .setEscape(options.getEscape()); + } else if (fileFormat instanceof FileFormat.DelimiterSeparatedTextReadOptions) { + FileFormat.DelimiterSeparatedTextReadOptions options = + (FileFormat.DelimiterSeparatedTextReadOptions) fileFormat; + ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.Builder + optionsBuilder = + ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions + .newBuilder() + .setFieldDelimiter(options.getFieldDelimiter()) + .setMaxLineSize(options.getMaxLineSize()) + .setQuote(options.getQuote()) + .setHeaderLinesToSkip(options.getHeaderLinesToSkip()) + .setEscape(options.getEscape()); options.getValueTreatedAsNull().ifPresent(optionsBuilder::setValueTreatedAsNull); builder.setText(optionsBuilder.build()); - } else if (fileFormat instanceof FileFormat.Extension options) { + } else if (fileFormat instanceof FileFormat.Extension) { + FileFormat.Extension options = (FileFormat.Extension) fileFormat; builder.setExtension(options.getExtension()); } else { throw new UnsupportedOperationException( @@ -72,11 +76,14 @@ default ReadRel.LocalFiles.FileOrFiles toProto() { getPath() .ifPresent( path -> { - switch (pathType) { - case URI_PATH -> builder.setUriPath(path); - case URI_PATH_GLOB -> builder.setUriPathGlob(path); - case URI_FILE -> builder.setUriFile(path); - case URI_FOLDER -> builder.setUriFolder(path); + if (pathType == PathType.URI_PATH) { + builder.setUriPath(path); + } else if (pathType == PathType.URI_PATH_GLOB) { + builder.setUriPathGlob(path); + } else if (pathType == PathType.URI_FILE) { + builder.setUriFile(path); + } else if (pathType == PathType.URI_FOLDER) { + builder.setUriFolder(path); } })); diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index 1d9cc3c54..246afa151 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -43,7 +43,7 @@ public static enum JoinType { } public static JoinType fromProto(HashJoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -58,23 +58,37 @@ public HashJoinRel.JoinType toProto() { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case RIGHT: + case OUTER: + return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); + case RIGHT_ANTI: + case RIGHT_SEMI: + return Stream.empty(); + default: + return getLeft().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java index 4f7facd32..34a51487a 100644 --- a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java @@ -43,7 +43,7 @@ public static enum JoinType { } public static JoinType fromProto(MergeJoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -58,23 +58,37 @@ public MergeJoinRel.JoinType toProto() { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java index 233aff176..17af10df5 100644 --- a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -40,7 +40,7 @@ public NestedLoopJoinRel.JoinType toProto() { } public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -52,23 +52,37 @@ public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case RIGHT: + case OUTER: + return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); + case RIGHT_ANTI: + case RIGHT_SEMI: + return Stream.empty(); + default: + return getLeft().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 8b0409085..e49b4936c 100644 --- a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -42,7 +42,7 @@ public ParseDeserializer( @Override public T deserialize(final JsonParser p, final DeserializationContext ctxt) throws IOException, JsonProcessingException { - var typeString = p.getValueAsString(); + String typeString = p.getValueAsString(); try { String namespace = (String) ctxt.findInjectableValue(SimpleExtension.URI_LOCATOR_KEY, null, null); diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 5e72a551a..9c241542c 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -20,7 +20,7 @@ static NamedStruct of(Iterable names, Type.Struct type) { } default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConverter) { - var type = struct().accept(typeProtoConverter); + io.substrait.proto.Type type = struct().accept(typeProtoConverter); return io.substrait.proto.NamedStruct.newBuilder() .setStruct(type.getStruct()) .addAllNames(names()) @@ -29,7 +29,7 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve static io.substrait.type.NamedStruct fromProto( io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { - var struct = namedStruct.getStruct(); + io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) .struct( diff --git a/core/src/main/java/io/substrait/type/YamlRead.java b/core/src/main/java/io/substrait/type/YamlRead.java index 89f8a9163..c4e7f4866 100644 --- a/core/src/main/java/io/substrait/type/YamlRead.java +++ b/core/src/main/java/io/substrait/type/YamlRead.java @@ -54,7 +54,8 @@ private static Stream parse(String name) { new ObjectMapper(new YAMLFactory()) .enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY) .registerModule(Deserializers.MODULE); - var doc = mapper.readValue(new File(name), SimpleExtension.ExtensionSignatures.class); + SimpleExtension.ExtensionSignatures doc = + mapper.readValue(new File(name), SimpleExtension.ExtensionSignatures.class); logger.atDebug().log( "Parsed {} functions in file {}.", diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index d27bd03e2..83b4ec8a4 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -1,5 +1,6 @@ package io.substrait.type.parser; +import io.substrait.function.ImmutableTypeExpression; import io.substrait.function.ParameterizedType; import io.substrait.function.ParameterizedTypeCreator; import io.substrait.function.TypeExpression; @@ -8,9 +9,11 @@ import io.substrait.type.SubstraitTypeVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.List; import java.util.Locale; import java.util.function.Function; import java.util.function.IntFunction; +import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ErrorNode; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.RuleNode; @@ -19,7 +22,7 @@ public class ParseToPojo { public static Type type(String namespace, SubstraitTypeParser.StartContext ctx) { - var visitor = Visitor.simple(namespace); + Visitor visitor = Visitor.simple(namespace); return (Type) ctx.accept(visitor); } @@ -160,12 +163,12 @@ public Type visitIntervalYear(final SubstraitTypeParser.IntervalYearContext ctx) public TypeExpression visitIntervalDay(final SubstraitTypeParser.IntervalDayContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).intervalDay(p); + if (precision instanceof Integer) { + return withNull(nullable).intervalDay((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).intervalDayE(s); + return withNullP(nullable).intervalDayE((String) precision); } checkExpression(); @@ -177,12 +180,12 @@ public TypeExpression visitIntervalCompound( final SubstraitTypeParser.IntervalCompoundContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).intervalCompound(p); + if (precision instanceof Integer) { + return withNull(nullable).intervalCompound((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).intervalCompoundE(s); + return withNullP(nullable).intervalCompoundE((String) precision); } checkExpression(); @@ -196,7 +199,7 @@ public Type visitUuid(final SubstraitTypeParser.UuidContext ctx) { @Override public Type visitUserDefined(SubstraitTypeParser.UserDefinedContext ctx) { - var name = ctx.Identifier().getSymbol().getText(); + String name = ctx.Identifier().getSymbol().getText(); return withNull(ctx).userDefined(namespace, name); } @@ -280,12 +283,12 @@ public TypeExpression visitPrecisionTimestamp( final SubstraitTypeParser.PrecisionTimestampContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).precisionTimestamp(p); + if (precision instanceof Integer) { + return withNull(nullable).precisionTimestamp((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).precisionTimestampE(s); + return withNullP(nullable).precisionTimestampE((String) precision); } checkExpression(); @@ -297,12 +300,12 @@ public TypeExpression visitPrecisionTimestampTZ( final SubstraitTypeParser.PrecisionTimestampTZContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).precisionTimestampTZ(p); + if (precision instanceof Integer) { + return withNull(nullable).precisionTimestampTZ((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).precisionTimestampTZE(s); + return withNullP(nullable).precisionTimestampTZE((String) precision); } checkExpression(); @@ -325,7 +328,7 @@ private Object i(SubstraitTypeParser.NumericParameterContext ctx) { @Override public TypeExpression visitStruct(final SubstraitTypeParser.StructContext ctx) { boolean nullable = ctx.isnull != null; - var types = + List types = ctx.expr().stream() .map(t -> t.accept(this)) .collect(java.util.stream.Collectors.toList()); @@ -456,17 +459,17 @@ public TypeExpression visitTernary(final SubstraitTypeParser.TernaryContext ctx) public TypeExpression visitMultilineDefinition( final SubstraitTypeParser.MultilineDefinitionContext ctx) { checkExpression(); - var exprs = + List exprs = ctx.expr().stream() .map(t -> t.accept(this)) .collect(java.util.stream.Collectors.toList()); - var identifiers = + List identifiers = ctx.Identifier().stream() .map(t -> t.getText()) .collect(java.util.stream.Collectors.toList()); - var finalExpr = ctx.finalType.accept(this); + TypeExpression finalExpr = ctx.finalType.accept(this); - var bldr = TypeExpression.ReturnProgram.builder(); + ImmutableTypeExpression.ReturnProgram.Builder bldr = TypeExpression.ReturnProgram.builder(); for (int i = 0; i < exprs.size(); i++) { bldr.addAssignments( TypeExpression.ReturnProgram.Assignment.builder() @@ -482,20 +485,7 @@ public TypeExpression visitMultilineDefinition( @Override public TypeExpression visitBinaryExpr(final SubstraitTypeParser.BinaryExprContext ctx) { checkExpression(); - TypeExpression.BinaryOperation.OpType type = - switch (ctx.op.getText().toUpperCase(Locale.ROOT)) { - case "+" -> TypeExpression.BinaryOperation.OpType.ADD; - case "-" -> TypeExpression.BinaryOperation.OpType.SUBTRACT; - case "*" -> TypeExpression.BinaryOperation.OpType.MULTIPLY; - case "/" -> TypeExpression.BinaryOperation.OpType.DIVIDE; - case ">" -> TypeExpression.BinaryOperation.OpType.GT; - case "<" -> TypeExpression.BinaryOperation.OpType.LT; - case "AND" -> TypeExpression.BinaryOperation.OpType.AND; - case "OR" -> TypeExpression.BinaryOperation.OpType.OR; - case "=" -> TypeExpression.BinaryOperation.OpType.EQ; - case ":=" -> TypeExpression.BinaryOperation.OpType.COVERS; - default -> throw new IllegalStateException("Unexpected value: " + ctx.op.getText()); - }; + TypeExpression.BinaryOperation.OpType type = getBinaryExpressionType(ctx.op); return TypeExpression.BinaryOperation.builder() .opType(type) .left(ctx.left.accept(this)) @@ -503,6 +493,33 @@ public TypeExpression visitBinaryExpr(final SubstraitTypeParser.BinaryExprContex .build(); } + private TypeExpression.BinaryOperation.OpType getBinaryExpressionType(Token token) { + switch (token.getText().toUpperCase(Locale.ROOT)) { + case "+": + return TypeExpression.BinaryOperation.OpType.ADD; + case "-": + return TypeExpression.BinaryOperation.OpType.SUBTRACT; + case "*": + return TypeExpression.BinaryOperation.OpType.MULTIPLY; + case "/": + return TypeExpression.BinaryOperation.OpType.DIVIDE; + case ">": + return TypeExpression.BinaryOperation.OpType.GT; + case "<": + return TypeExpression.BinaryOperation.OpType.LT; + case "AND": + return TypeExpression.BinaryOperation.OpType.AND; + case "OR": + return TypeExpression.BinaryOperation.OpType.OR; + case "=": + return TypeExpression.BinaryOperation.OpType.EQ; + case ":=": + return TypeExpression.BinaryOperation.OpType.COVERS; + default: + throw new IllegalStateException("Unexpected value: " + token.getText()); + } + } + @Override public TypeExpression visitNumericLiteral(final SubstraitTypeParser.NumericLiteralContext ctx) { return TypeExpression.IntegerLiteral.builder().value(Integer.parseInt(ctx.getText())).build(); @@ -533,14 +550,7 @@ public TypeExpression visitFunctionCall(final SubstraitTypeParser.FunctionCallCo if (ctx.expr().size() != 2) { throw new IllegalStateException("Only two argument functions exist for type expressions."); } - var name = ctx.Identifier().getSymbol().getText().toUpperCase(Locale.ROOT); - TypeExpression.BinaryOperation.OpType type = - switch (name) { - case "MIN" -> TypeExpression.BinaryOperation.OpType.MIN; - case "MAX" -> TypeExpression.BinaryOperation.OpType.MAX; - default -> throw new IllegalStateException( - "The following operation was unrecognized: " + name); - }; + TypeExpression.BinaryOperation.OpType type = getFunctionType(ctx.Identifier().getSymbol()); return TypeExpression.BinaryOperation.builder() .opType(type) .left(ctx.expr(0).accept(this)) @@ -548,6 +558,18 @@ public TypeExpression visitFunctionCall(final SubstraitTypeParser.FunctionCallCo .build(); } + private TypeExpression.BinaryOperation.OpType getFunctionType(Token token) { + switch (token.getText().toUpperCase(Locale.ROOT)) { + case "MIN": + return TypeExpression.BinaryOperation.OpType.MIN; + case "MAX": + return TypeExpression.BinaryOperation.OpType.MAX; + default: + throw new IllegalStateException( + "The following operation was unrecognized: " + token.getText()); + } + } + @Override public TypeExpression visitNotExpr(final SubstraitTypeParser.NotExprContext ctx) { return TypeExpression.NotOperation.builder().inner(ctx.expr().accept(this)).build(); diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index df45fca8e..b22c68265 100644 --- a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -30,11 +30,11 @@ public static TypeExpression parseExpression(String str, String namespace) { } private static SubstraitTypeParser.StartContext parse(String str) { - var lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); + SubstraitTypeLexer lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); lexer.removeErrorListeners(); lexer.addErrorListener(TypeErrorListener.INSTANCE); - var tokenStream = new CommonTokenStream(lexer); - var parser = new io.substrait.type.SubstraitTypeParser(tokenStream); + CommonTokenStream tokenStream = new CommonTokenStream(lexer); + SubstraitTypeParser parser = new io.substrait.type.SubstraitTypeParser(tokenStream); parser.removeErrorListeners(); parser.addErrorListener(TypeErrorListener.INSTANCE); return parser.start(); diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 95734321e..8a37a2928 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,7 +165,7 @@ public final T visit(final Type.Map expr) { @Override public final T visit(final Type.UserDefined expr) { - var ref = + int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); return typeContainer(expr).userDefined(ref); } diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index baf5c411c..293b25c90 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -264,59 +264,62 @@ public ParameterizedType userDefined(int ref) { @Override protected ParameterizedType wrap(final Object o) { - var bldr = ParameterizedType.newBuilder(); - if (o instanceof Type.Boolean t) { - return bldr.setBool(t).build(); - } else if (o instanceof Type.I8 t) { - return bldr.setI8(t).build(); - } else if (o instanceof Type.I16 t) { - return bldr.setI16(t).build(); - } else if (o instanceof Type.I32 t) { - return bldr.setI32(t).build(); - } else if (o instanceof Type.I64 t) { - return bldr.setI64(t).build(); - } else if (o instanceof Type.FP32 t) { - return bldr.setFp32(t).build(); - } else if (o instanceof Type.FP64 t) { - return bldr.setFp64(t).build(); - } else if (o instanceof Type.String t) { - return bldr.setString(t).build(); - } else if (o instanceof Type.Binary t) { - return bldr.setBinary(t).build(); - } else if (o instanceof Type.Timestamp t) { - return bldr.setTimestamp(t).build(); - } else if (o instanceof Type.Date t) { - return bldr.setDate(t).build(); - } else if (o instanceof Type.Time t) { - return bldr.setTime(t).build(); - } else if (o instanceof Type.TimestampTZ t) { - return bldr.setTimestampTz(t).build(); - } else if (o instanceof Type.IntervalYear t) { - return bldr.setIntervalYear(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedIntervalDay t) { - return bldr.setIntervalDay(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedIntervalCompound t) { - return bldr.setIntervalCompound(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedFixedChar t) { - return bldr.setFixedChar(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedVarChar t) { - return bldr.setVarchar(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedFixedBinary t) { - return bldr.setFixedBinary(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedDecimal t) { - return bldr.setDecimal(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestamp t) { - return bldr.setPrecisionTimestamp(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestampTZ t) { - return bldr.setPrecisionTimestampTz(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedStruct t) { - return bldr.setStruct(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedList t) { - return bldr.setList(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedMap t) { - return bldr.setMap(t).build(); - } else if (o instanceof Type.UUID t) { - return bldr.setUuid(t).build(); + ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); + if (o instanceof Type.Boolean) { + return bldr.setBool((Type.Boolean) o).build(); + } else if (o instanceof Type.I8) { + return bldr.setI8((Type.I8) o).build(); + } else if (o instanceof Type.I16) { + return bldr.setI16((Type.I16) o).build(); + } else if (o instanceof Type.I32) { + return bldr.setI32((Type.I32) o).build(); + } else if (o instanceof Type.I64) { + return bldr.setI64((Type.I64) o).build(); + } else if (o instanceof Type.FP32) { + return bldr.setFp32((Type.FP32) o).build(); + } else if (o instanceof Type.FP64) { + return bldr.setFp64((Type.FP64) o).build(); + } else if (o instanceof Type.String) { + return bldr.setString((Type.String) o).build(); + } else if (o instanceof Type.Binary) { + return bldr.setBinary((Type.Binary) o).build(); + } else if (o instanceof Type.Timestamp) { + return bldr.setTimestamp((Type.Timestamp) o).build(); + } else if (o instanceof Type.Date) { + return bldr.setDate((Type.Date) o).build(); + } else if (o instanceof Type.Time) { + return bldr.setTime((Type.Time) o).build(); + } else if (o instanceof Type.TimestampTZ) { + return bldr.setTimestampTz((Type.TimestampTZ) o).build(); + } else if (o instanceof Type.IntervalYear) { + return bldr.setIntervalYear((Type.IntervalYear) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedIntervalDay) { + return bldr.setIntervalDay((ParameterizedType.ParameterizedIntervalDay) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedIntervalCompound) { + return bldr.setIntervalCompound((ParameterizedType.ParameterizedIntervalCompound) o) + .build(); + } else if (o instanceof ParameterizedType.ParameterizedFixedChar) { + return bldr.setFixedChar((ParameterizedType.ParameterizedFixedChar) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedVarChar) { + return bldr.setVarchar((ParameterizedType.ParameterizedVarChar) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedFixedBinary) { + return bldr.setFixedBinary((ParameterizedType.ParameterizedFixedBinary) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedDecimal) { + return bldr.setDecimal((ParameterizedType.ParameterizedDecimal) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestamp) { + return bldr.setPrecisionTimestamp((ParameterizedType.ParameterizedPrecisionTimestamp) o) + .build(); + } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestampTZ) { + return bldr.setPrecisionTimestampTz((ParameterizedType.ParameterizedPrecisionTimestampTZ) o) + .build(); + } else if (o instanceof ParameterizedType.ParameterizedStruct) { + return bldr.setStruct((ParameterizedType.ParameterizedStruct) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedList) { + return bldr.setList((ParameterizedType.ParameterizedList) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedMap) { + return bldr.setMap((ParameterizedType.ParameterizedMap) o).build(); + } else if (o instanceof Type.UUID) { + return bldr.setUuid((Type.UUID) o).build(); } throw new UnsupportedOperationException("Unable to wrap type of " + o.getClass()); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index e4e33bacc..661f57fea 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -18,57 +18,87 @@ public ProtoTypeConverter( } public Type from(io.substrait.proto.Type type) { - return switch (type.getKindCase()) { - case BOOL -> n(type.getBool().getNullability()).BOOLEAN; - case I8 -> n(type.getI8().getNullability()).I8; - case I16 -> n(type.getI16().getNullability()).I16; - case I32 -> n(type.getI32().getNullability()).I32; - case I64 -> n(type.getI64().getNullability()).I64; - case FP32 -> n(type.getFp32().getNullability()).FP32; - case FP64 -> n(type.getFp64().getNullability()).FP64; - case STRING -> n(type.getString().getNullability()).STRING; - case BINARY -> n(type.getBinary().getNullability()).BINARY; - case TIMESTAMP -> n(type.getTimestamp().getNullability()).TIMESTAMP; - case DATE -> n(type.getDate().getNullability()).DATE; - case TIME -> n(type.getTime().getNullability()).TIME; - case INTERVAL_YEAR -> n(type.getIntervalYear().getNullability()).INTERVAL_YEAR; - case INTERVAL_DAY -> n(type.getIntervalDay().getNullability()) - // precision defaults to 6 (micros) for backwards compatibility, see protobuf - .intervalDay( - type.getIntervalDay().hasPrecision() ? type.getIntervalDay().getPrecision() : 6); - case INTERVAL_COMPOUND -> n(type.getIntervalCompound().getNullability()) - .intervalCompound(type.getIntervalCompound().getPrecision()); - case TIMESTAMP_TZ -> n(type.getTimestampTz().getNullability()).TIMESTAMP_TZ; - case UUID -> n(type.getUuid().getNullability()).UUID; - case FIXED_CHAR -> n(type.getFixedChar().getNullability()) - .fixedChar(type.getFixedChar().getLength()); - case VARCHAR -> n(type.getVarchar().getNullability()).varChar(type.getVarchar().getLength()); - case FIXED_BINARY -> n(type.getFixedBinary().getNullability()) - .fixedBinary(type.getFixedBinary().getLength()); - case DECIMAL -> n(type.getDecimal().getNullability()) - .decimal(type.getDecimal().getPrecision(), type.getDecimal().getScale()); - case PRECISION_TIME -> n(type.getPrecisionTime().getNullability()) - .precisionTime(type.getPrecisionTime().getPrecision()); - case PRECISION_TIMESTAMP -> n(type.getPrecisionTimestamp().getNullability()) - .precisionTimestamp(type.getPrecisionTimestamp().getPrecision()); - case PRECISION_TIMESTAMP_TZ -> n(type.getPrecisionTimestampTz().getNullability()) - .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); - case STRUCT -> n(type.getStruct().getNullability()) - .struct( - type.getStruct().getTypesList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())); - case LIST -> fromList(type.getList()); - case MAP -> fromMap(type.getMap()); - case USER_DEFINED -> { - var userDefined = type.getUserDefined(); - var t = lookup.getType(userDefined.getTypeReference(), extensions); - yield n(userDefined.getNullability()).userDefined(t.uri(), t.name()); - } - case USER_DEFINED_TYPE_REFERENCE -> throw new UnsupportedOperationException( - "Unsupported user defined reference: " + type); - case KIND_NOT_SET -> throw new UnsupportedOperationException("Type is not set: " + type); - }; + switch (type.getKindCase()) { + case BOOL: + return n(type.getBool().getNullability()).BOOLEAN; + case I8: + return n(type.getI8().getNullability()).I8; + case I16: + return n(type.getI16().getNullability()).I16; + case I32: + return n(type.getI32().getNullability()).I32; + case I64: + return n(type.getI64().getNullability()).I64; + case FP32: + return n(type.getFp32().getNullability()).FP32; + case FP64: + return n(type.getFp64().getNullability()).FP64; + case STRING: + return n(type.getString().getNullability()).STRING; + case BINARY: + return n(type.getBinary().getNullability()).BINARY; + case TIMESTAMP: + return n(type.getTimestamp().getNullability()).TIMESTAMP; + case DATE: + return n(type.getDate().getNullability()).DATE; + case TIME: + return n(type.getTime().getNullability()).TIME; + case INTERVAL_YEAR: + return n(type.getIntervalYear().getNullability()).INTERVAL_YEAR; + case INTERVAL_DAY: + return n(type.getIntervalDay().getNullability()) + // precision defaults to 6 (micros) for backwards compatibility, see protobuf + .intervalDay( + type.getIntervalDay().hasPrecision() ? type.getIntervalDay().getPrecision() : 6); + case INTERVAL_COMPOUND: + return n(type.getIntervalCompound().getNullability()) + .intervalCompound(type.getIntervalCompound().getPrecision()); + case TIMESTAMP_TZ: + return n(type.getTimestampTz().getNullability()).TIMESTAMP_TZ; + case UUID: + return n(type.getUuid().getNullability()).UUID; + case FIXED_CHAR: + return n(type.getFixedChar().getNullability()).fixedChar(type.getFixedChar().getLength()); + case VARCHAR: + return n(type.getVarchar().getNullability()).varChar(type.getVarchar().getLength()); + case FIXED_BINARY: + return n(type.getFixedBinary().getNullability()) + .fixedBinary(type.getFixedBinary().getLength()); + case DECIMAL: + return n(type.getDecimal().getNullability()) + .decimal(type.getDecimal().getPrecision(), type.getDecimal().getScale()); + case PRECISION_TIME: + return n(type.getPrecisionTime().getNullability()) + .precisionTime(type.getPrecisionTime().getPrecision()); + case PRECISION_TIMESTAMP: + return n(type.getPrecisionTimestamp().getNullability()) + .precisionTimestamp(type.getPrecisionTimestamp().getPrecision()); + case PRECISION_TIMESTAMP_TZ: + return n(type.getPrecisionTimestampTz().getNullability()) + .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case STRUCT: + return n(type.getStruct().getNullability()) + .struct( + type.getStruct().getTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList())); + case LIST: + return fromList(type.getList()); + case MAP: + return fromMap(type.getMap()); + case USER_DEFINED: + { + io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); + SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); + return n(userDefined.getNullability()).userDefined(t.uri(), t.name()); + } + case USER_DEFINED_TYPE_REFERENCE: + throw new UnsupportedOperationException("Unsupported user defined reference: " + type); + case KIND_NOT_SET: + throw new UnsupportedOperationException("Type is not set: " + type); + default: + throw new UnsupportedOperationException("Unsupported type: " + type.getKindCase()); + } } public Type.ListType fromList(io.substrait.proto.Type.List list) { diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 1a7dcbc4f..44997bf2a 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -4,7 +4,9 @@ import io.substrait.function.ParameterizedType; import io.substrait.function.TypeExpression; import io.substrait.proto.DerivationExpression; +import io.substrait.proto.DerivationExpression.ReturnProgram.Assignment; import io.substrait.proto.Type; +import java.util.List; public class TypeExpressionProtoVisitor extends BaseProtoConverter { @@ -28,21 +30,7 @@ public BaseProtoTypes typeContainer( @Override public DerivationExpression visit(final TypeExpression.BinaryOperation expr) { - var opType = - switch (expr.opType()) { - case ADD -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_PLUS; - case SUBTRACT -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; - case MIN -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MIN; - case MAX -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MAX; - case LT -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; - // case LTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; - case GT -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_GREATER_THAN; - // case GTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; - // case NOT_EQ -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQ; - case EQ -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQUALS; - case COVERS -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_COVERS; - default -> throw new IllegalStateException("Unexpected value: " + expr.opType()); - }; + DerivationExpression.BinaryOp.BinaryOpType opType = getDerivationOpType(expr.opType()); return DerivationExpression.newBuilder() .setBinaryOp( DerivationExpression.BinaryOp.newBuilder() @@ -53,6 +41,33 @@ public DerivationExpression visit(final TypeExpression.BinaryOperation expr) { .build(); } + private DerivationExpression.BinaryOp.BinaryOpType getDerivationOpType( + TypeExpression.BinaryOperation.OpType type) { + switch (type) { + case ADD: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_PLUS; + case SUBTRACT: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; + case MIN: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MIN; + case MAX: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MAX; + case LT: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; + // case LTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; + case GT: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_GREATER_THAN; + // case GTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; + // case NOT_EQ -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQ; + case EQ: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQUALS; + case COVERS: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_COVERS; + default: + throw new IllegalStateException("Unexpected value: " + type); + } + } + @Override public DerivationExpression visit(final TypeExpression.NotOperation expr) { return DerivationExpression.newBuilder() @@ -82,7 +97,7 @@ public DerivationExpression visit(final TypeExpression.IntegerLiteral expr) { @Override public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { - var assignments = + List assignments = expr.assignments().stream() .map( a -> @@ -91,7 +106,7 @@ public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { .setExpression(a.expr().accept(this)) .build()) .collect(java.util.stream.Collectors.toList()); - var finalExpr = expr.finalExpression().accept(this); + DerivationExpression finalExpr = expr.finalExpression().accept(this); return DerivationExpression.newBuilder() .setReturnProgram( DerivationExpression.ReturnProgram.newBuilder() @@ -337,59 +352,62 @@ public DerivationExpression userDefined(int ref) { @Override protected DerivationExpression wrap(final Object o) { - var bldr = DerivationExpression.newBuilder(); - if (o instanceof Type.Boolean t) { - return bldr.setBool(t).build(); - } else if (o instanceof Type.I8 t) { - return bldr.setI8(t).build(); - } else if (o instanceof Type.I16 t) { - return bldr.setI16(t).build(); - } else if (o instanceof Type.I32 t) { - return bldr.setI32(t).build(); - } else if (o instanceof Type.I64 t) { - return bldr.setI64(t).build(); - } else if (o instanceof Type.FP32 t) { - return bldr.setFp32(t).build(); - } else if (o instanceof Type.FP64 t) { - return bldr.setFp64(t).build(); - } else if (o instanceof Type.String t) { - return bldr.setString(t).build(); - } else if (o instanceof Type.Binary t) { - return bldr.setBinary(t).build(); - } else if (o instanceof Type.Timestamp t) { - return bldr.setTimestamp(t).build(); - } else if (o instanceof Type.Date t) { - return bldr.setDate(t).build(); - } else if (o instanceof Type.Time t) { - return bldr.setTime(t).build(); - } else if (o instanceof Type.TimestampTZ t) { - return bldr.setTimestampTz(t).build(); - } else if (o instanceof Type.IntervalYear t) { - return bldr.setIntervalYear(t).build(); - } else if (o instanceof DerivationExpression.ExpressionIntervalDay t) { - return bldr.setIntervalDay(t).build(); - } else if (o instanceof DerivationExpression.ExpressionIntervalCompound t) { - return bldr.setIntervalCompound(t).build(); - } else if (o instanceof DerivationExpression.ExpressionFixedChar t) { - return bldr.setFixedChar(t).build(); - } else if (o instanceof DerivationExpression.ExpressionVarChar t) { - return bldr.setVarchar(t).build(); - } else if (o instanceof DerivationExpression.ExpressionFixedBinary t) { - return bldr.setFixedBinary(t).build(); - } else if (o instanceof DerivationExpression.ExpressionDecimal t) { - return bldr.setDecimal(t).build(); - } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestamp t) { - return bldr.setPrecisionTimestamp(t).build(); - } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestampTZ t) { - return bldr.setPrecisionTimestampTz(t).build(); - } else if (o instanceof DerivationExpression.ExpressionStruct t) { - return bldr.setStruct(t).build(); - } else if (o instanceof DerivationExpression.ExpressionList t) { - return bldr.setList(t).build(); - } else if (o instanceof DerivationExpression.ExpressionMap t) { - return bldr.setMap(t).build(); - } else if (o instanceof Type.UUID t) { - return bldr.setUuid(t).build(); + DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); + if (o instanceof Type.Boolean) { + return bldr.setBool((Type.Boolean) o).build(); + } else if (o instanceof Type.I8) { + return bldr.setI8((Type.I8) o).build(); + } else if (o instanceof Type.I16) { + return bldr.setI16((Type.I16) o).build(); + } else if (o instanceof Type.I32) { + return bldr.setI32((Type.I32) o).build(); + } else if (o instanceof Type.I64) { + return bldr.setI64((Type.I64) o).build(); + } else if (o instanceof Type.FP32) { + return bldr.setFp32((Type.FP32) o).build(); + } else if (o instanceof Type.FP64) { + return bldr.setFp64((Type.FP64) o).build(); + } else if (o instanceof Type.String) { + return bldr.setString((Type.String) o).build(); + } else if (o instanceof Type.Binary) { + return bldr.setBinary((Type.Binary) o).build(); + } else if (o instanceof Type.Timestamp) { + return bldr.setTimestamp((Type.Timestamp) o).build(); + } else if (o instanceof Type.Date) { + return bldr.setDate((Type.Date) o).build(); + } else if (o instanceof Type.Time) { + return bldr.setTime((Type.Time) o).build(); + } else if (o instanceof Type.TimestampTZ) { + return bldr.setTimestampTz((Type.TimestampTZ) o).build(); + } else if (o instanceof Type.IntervalYear) { + return bldr.setIntervalYear((Type.IntervalYear) o).build(); + } else if (o instanceof DerivationExpression.ExpressionIntervalDay) { + return bldr.setIntervalDay((DerivationExpression.ExpressionIntervalDay) o).build(); + } else if (o instanceof DerivationExpression.ExpressionIntervalCompound) { + return bldr.setIntervalCompound((DerivationExpression.ExpressionIntervalCompound) o) + .build(); + } else if (o instanceof DerivationExpression.ExpressionFixedChar) { + return bldr.setFixedChar((DerivationExpression.ExpressionFixedChar) o).build(); + } else if (o instanceof DerivationExpression.ExpressionVarChar) { + return bldr.setVarchar((DerivationExpression.ExpressionVarChar) o).build(); + } else if (o instanceof DerivationExpression.ExpressionFixedBinary) { + return bldr.setFixedBinary((DerivationExpression.ExpressionFixedBinary) o).build(); + } else if (o instanceof DerivationExpression.ExpressionDecimal) { + return bldr.setDecimal((DerivationExpression.ExpressionDecimal) o).build(); + } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestamp) { + return bldr.setPrecisionTimestamp((DerivationExpression.ExpressionPrecisionTimestamp) o) + .build(); + } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestampTZ) { + return bldr.setPrecisionTimestampTz((DerivationExpression.ExpressionPrecisionTimestampTZ) o) + .build(); + } else if (o instanceof DerivationExpression.ExpressionStruct) { + return bldr.setStruct((DerivationExpression.ExpressionStruct) o).build(); + } else if (o instanceof DerivationExpression.ExpressionList) { + return bldr.setList((DerivationExpression.ExpressionList) o).build(); + } else if (o instanceof DerivationExpression.ExpressionMap) { + return bldr.setMap((DerivationExpression.ExpressionMap) o).build(); + } else if (o instanceof Type.UUID) { + return bldr.setUuid((Type.UUID) o).build(); } throw new UnsupportedOperationException("Unable to wrap type of " + o.getClass()); } diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 626b77166..5162b1c78 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -124,61 +124,61 @@ public Type userDefined(int ref) { @Override protected Type wrap(final Object o) { - var bldr = Type.newBuilder(); - if (o instanceof Type.Boolean t) { - return bldr.setBool(t).build(); - } else if (o instanceof Type.I8 t) { - return bldr.setI8(t).build(); - } else if (o instanceof Type.I16 t) { - return bldr.setI16(t).build(); - } else if (o instanceof Type.I32 t) { - return bldr.setI32(t).build(); - } else if (o instanceof Type.I64 t) { - return bldr.setI64(t).build(); - } else if (o instanceof Type.FP32 t) { - return bldr.setFp32(t).build(); - } else if (o instanceof Type.FP64 t) { - return bldr.setFp64(t).build(); - } else if (o instanceof Type.String t) { - return bldr.setString(t).build(); - } else if (o instanceof Type.Binary t) { - return bldr.setBinary(t).build(); - } else if (o instanceof Type.Timestamp t) { - return bldr.setTimestamp(t).build(); - } else if (o instanceof Type.Date t) { - return bldr.setDate(t).build(); - } else if (o instanceof Type.Time t) { - return bldr.setTime(t).build(); - } else if (o instanceof Type.TimestampTZ t) { - return bldr.setTimestampTz(t).build(); - } else if (o instanceof Type.IntervalYear t) { - return bldr.setIntervalYear(t).build(); - } else if (o instanceof Type.IntervalDay t) { - return bldr.setIntervalDay(t).build(); - } else if (o instanceof Type.IntervalCompound t) { - return bldr.setIntervalCompound(t).build(); - } else if (o instanceof Type.FixedChar t) { - return bldr.setFixedChar(t).build(); - } else if (o instanceof Type.VarChar t) { - return bldr.setVarchar(t).build(); - } else if (o instanceof Type.FixedBinary t) { - return bldr.setFixedBinary(t).build(); - } else if (o instanceof Type.Decimal t) { - return bldr.setDecimal(t).build(); - } else if (o instanceof Type.PrecisionTimestamp t) { - return bldr.setPrecisionTimestamp(t).build(); - } else if (o instanceof Type.PrecisionTimestampTZ t) { - return bldr.setPrecisionTimestampTz(t).build(); - } else if (o instanceof Type.Struct t) { - return bldr.setStruct(t).build(); - } else if (o instanceof Type.List t) { - return bldr.setList(t).build(); - } else if (o instanceof Type.Map t) { - return bldr.setMap(t).build(); - } else if (o instanceof Type.UUID t) { - return bldr.setUuid(t).build(); - } else if (o instanceof Type.UserDefined t) { - return bldr.setUserDefined(t).build(); + Type.Builder bldr = Type.newBuilder(); + if (o instanceof Type.Boolean) { + return bldr.setBool((Type.Boolean) o).build(); + } else if (o instanceof Type.I8) { + return bldr.setI8((Type.I8) o).build(); + } else if (o instanceof Type.I16) { + return bldr.setI16((Type.I16) o).build(); + } else if (o instanceof Type.I32) { + return bldr.setI32((Type.I32) o).build(); + } else if (o instanceof Type.I64) { + return bldr.setI64((Type.I64) o).build(); + } else if (o instanceof Type.FP32) { + return bldr.setFp32((Type.FP32) o).build(); + } else if (o instanceof Type.FP64) { + return bldr.setFp64((Type.FP64) o).build(); + } else if (o instanceof Type.String) { + return bldr.setString((Type.String) o).build(); + } else if (o instanceof Type.Binary) { + return bldr.setBinary((Type.Binary) o).build(); + } else if (o instanceof Type.Timestamp) { + return bldr.setTimestamp((Type.Timestamp) o).build(); + } else if (o instanceof Type.Date) { + return bldr.setDate((Type.Date) o).build(); + } else if (o instanceof Type.Time) { + return bldr.setTime((Type.Time) o).build(); + } else if (o instanceof Type.TimestampTZ) { + return bldr.setTimestampTz((Type.TimestampTZ) o).build(); + } else if (o instanceof Type.IntervalYear) { + return bldr.setIntervalYear((Type.IntervalYear) o).build(); + } else if (o instanceof Type.IntervalDay) { + return bldr.setIntervalDay((Type.IntervalDay) o).build(); + } else if (o instanceof Type.IntervalCompound) { + return bldr.setIntervalCompound((Type.IntervalCompound) o).build(); + } else if (o instanceof Type.FixedChar) { + return bldr.setFixedChar((Type.FixedChar) o).build(); + } else if (o instanceof Type.VarChar) { + return bldr.setVarchar((Type.VarChar) o).build(); + } else if (o instanceof Type.FixedBinary) { + return bldr.setFixedBinary((Type.FixedBinary) o).build(); + } else if (o instanceof Type.Decimal) { + return bldr.setDecimal((Type.Decimal) o).build(); + } else if (o instanceof Type.PrecisionTimestamp) { + return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); + } else if (o instanceof Type.PrecisionTimestampTZ) { + return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Struct) { + return bldr.setStruct((Type.Struct) o).build(); + } else if (o instanceof Type.List) { + return bldr.setList((Type.List) o).build(); + } else if (o instanceof Type.Map) { + return bldr.setMap((Type.Map) o).build(); + } else if (o instanceof Type.UUID) { + return bldr.setUuid((Type.UUID) o).build(); + } else if (o instanceof Type.UserDefined) { + return bldr.setUserDefined((Type.UserDefined) o).build(); } throw new UnsupportedOperationException("Unable to wrap type of " + o.getClass()); } diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index f19bbef5b..cd2522090 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -73,8 +73,8 @@ void roundtripCustomType() { .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - var protoPlan = planProtoConverter.toProto(plan); - var planReturned = protoPlanConverter.from(protoPlan); + io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } @@ -99,8 +99,8 @@ void roundtripNumberedAnyTypes() { b.fieldReference(input, 0))) .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - var protoPlan = planProtoConverter.toProto(plan); - var planReturned = protoPlanConverter.from(protoPlan); + io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } } diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java index 9a84f137b..b8dd03465 100644 --- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java +++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java @@ -25,8 +25,8 @@ public class ProtoRelConverterTest extends TestBase { @Nested class DefaultAdvancedExtensionTests { - static final StringHolder ENHANCED = new StringHolder("ENHANCED"); - static final StringHolder OPTIMIZED = new StringHolder("OPTIMIZED"); + final StringHolder ENHANCED = new StringHolder("ENHANCED"); + final StringHolder OPTIMIZED = new StringHolder("OPTIMIZED"); Rel relWithExtension(AdvancedExtension advancedExtension) { return NamedScan.builder() diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java index abbcb24da..8a8118791 100644 --- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java @@ -22,20 +22,21 @@ public class AggregateRoundtripTest extends TestBase { private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) { - var expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); + Expression.DecimalLiteral expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); Expression.StructLiteral literal = Expression.StructLiteral.builder().addFields(expression).build(); - var input = + io.substrait.relation.ImmutableVirtualTableScan input = VirtualTableScan.builder() .initialSchema(NamedStruct.of(Arrays.asList("decimal"), R.struct(R.decimal(10, 2)))) .addRows(literal) .build(); ExtensionCollector functionCollector = new ExtensionCollector(); - var to = new RelProtoConverter(functionCollector); - var extensions = defaultExtensionCollection; - var from = new ProtoRelConverter(functionCollector, extensions); + RelProtoConverter to = new RelProtoConverter(functionCollector); + io.substrait.extension.SimpleExtension.ExtensionCollection extensions = + defaultExtensionCollection; + ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions); - var measure = + io.substrait.relation.ImmutableMeasure measure = Aggregate.Measure.builder() .function( AggregateFunctionInvocation.builder() @@ -60,8 +61,9 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio .build()) .build(); - var aggRel = Aggregate.builder().input(input).measures(Arrays.asList(measure)).build(); - var protoAggRel = to.toProto(aggRel); + io.substrait.relation.ImmutableAggregate aggRel = + Aggregate.builder().input(input).measures(Arrays.asList(measure)).build(); + io.substrait.proto.Rel protoAggRel = to.toProto(aggRel); assertEquals( protoAggRel.getAggregate().getMeasuresList().get(0).getMeasure().getInvocation(), invocation.toProto()); @@ -70,7 +72,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio @Test void aggregateInvocationRoundtrip() { - for (var invocation : Expression.AggregationInvocation.values()) { + for (Expression.AggregationInvocation invocation : Expression.AggregationInvocation.values()) { assertAggregateRoundtrip(invocation); } } diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java index f82eedc12..9c434be17 100644 --- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java @@ -17,7 +17,7 @@ public class ConsistentPartitionWindowRelRoundtripTest extends TestBase { @Test void consistentPartitionWindowRoundtripSingle() { - var windowFunctionDeclaration = + SimpleExtension.WindowFunctionVariant windowFunctionDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); @@ -70,11 +70,11 @@ void consistentPartitionWindowRoundtripSingle() { @Test void consistentPartitionWindowRoundtripMulti() { - var windowFunctionLeadDeclaration = + SimpleExtension.WindowFunctionVariant windowFunctionLeadDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); - var windowFunctionLagDeclaration = + SimpleExtension.WindowFunctionVariant windowFunctionLagDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 7b2237661..5b1ec5322 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -306,7 +306,7 @@ class ExtensionThroughExpression { @Test void scalarSubquery() { - var rel = + Project rel = b.project( input -> Stream.of( @@ -322,7 +322,7 @@ void scalarSubquery() { @Test void inPredicate() { - var rel = + Project rel = b.project( input -> Stream.of( @@ -337,7 +337,7 @@ void inPredicate() { @Test void setPredicate() { - var rel = + Project rel = b.project( input -> Stream.of( diff --git a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java index 1471154d8..47f53c110 100644 --- a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java @@ -51,8 +51,8 @@ public void roundtripTest(Method m, List paramInst, UnsupportedTypeGener // roundtrip to protobuff and back and check equality Expression val = (Expression) m.invoke(null, paramInst.toArray(new Object[0])); - var to = new ExpressionProtoConverter(null, null); - var from = + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = new ProtoExpressionConverter( null, null, diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java index cfcdaf6fc..7e613f272 100644 --- a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -25,8 +25,9 @@ void ifThenNotNullable() { ExpressionCreator.i64(false, 2)); assertFalse(ifRel.getType().nullable()); - var to = new ExpressionProtoConverter(null, null); - var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = + new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } @@ -39,8 +40,9 @@ void ifThenNullable() { ExpressionCreator.i64(false, 2)); assertTrue(ifRel.getType().nullable()); - var to = new ExpressionProtoConverter(null, null); - var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = + new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index b0f1b5fe3..ccac93bcb 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -15,9 +15,11 @@ public class LiteralRoundtripTest extends TestBase { @Test void decimal() { - var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - var to = new ExpressionProtoConverter(null, null); - var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + io.substrait.expression.Expression.DecimalLiteral val = + ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = + new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); } } diff --git a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java index 8489860bb..f19535f65 100644 --- a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java @@ -20,7 +20,7 @@ public class LocalFilesRoundtripTest extends TestBase { private void assertLocalFilesRoundtrip(FileOrFiles file) { - var builder = + io.substrait.relation.ImmutableLocalFiles.Builder builder = LocalFiles.builder() .initialSchema( NamedStruct.builder() @@ -48,8 +48,8 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) { ExpressionCreator.i32(false, 1))) .ifPresent(builder::filter); - var localFiles = builder.build(); - var protoFileRel = relProtoConverter.toProto(localFiles); + io.substrait.relation.ImmutableLocalFiles localFiles = builder.build(); + io.substrait.proto.Rel protoFileRel = relProtoConverter.toProto(localFiles); assertTrue(protoFileRel.getRead().hasFilter()); assertEquals(protoFileRel, relProtoConverter.toProto(protoRelConverter.from(protoFileRel))); } @@ -57,37 +57,53 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) { private ImmutableFileOrFiles.Builder setPath( ImmutableFileOrFiles.Builder builder, ReadRel.LocalFiles.FileOrFiles.PathTypeCase pathTypeCase) { - return switch (pathTypeCase) { - case URI_PATH -> builder.pathType(FileOrFiles.PathType.URI_PATH).path("path"); - case URI_PATH_GLOB -> builder.pathType(FileOrFiles.PathType.URI_PATH_GLOB).path("path"); - case URI_FILE -> builder.pathType(FileOrFiles.PathType.URI_FILE).path("path"); - case URI_FOLDER -> builder.pathType(FileOrFiles.PathType.URI_FOLDER).path("path"); - case PATHTYPE_NOT_SET -> builder; - }; + switch (pathTypeCase) { + case URI_PATH: + return builder.pathType(FileOrFiles.PathType.URI_PATH).path("path"); + case URI_PATH_GLOB: + return builder.pathType(FileOrFiles.PathType.URI_PATH_GLOB).path("path"); + case URI_FILE: + return builder.pathType(FileOrFiles.PathType.URI_FILE).path("path"); + case URI_FOLDER: + return builder.pathType(FileOrFiles.PathType.URI_FOLDER).path("path"); + case PATHTYPE_NOT_SET: + return builder; + default: + throw new IllegalArgumentException("Unknown path type case: " + pathTypeCase); + } } private ImmutableFileOrFiles.Builder setFileFormat( ImmutableFileOrFiles.Builder builder, ReadRel.LocalFiles.FileOrFiles.FileFormatCase fileFormatCase) { - return switch (fileFormatCase) { - case PARQUET -> builder.fileFormat(FileFormat.ParquetReadOptions.builder().build()); - case ARROW -> builder.fileFormat(FileFormat.ArrowReadOptions.builder().build()); - case ORC -> builder.fileFormat(FileFormat.OrcReadOptions.builder().build()); - case DWRF -> builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); - case TEXT -> builder.fileFormat( - FileFormat.DelimiterSeparatedTextReadOptions.builder() - .fieldDelimiter("|") - .maxLineSize(1000) - .quote("\"") - .headerLinesToSkip(1) - .escape("\\") - .build()); - case EXTENSION -> builder.fileFormat( - FileFormat.Extension.builder() - .extension(com.google.protobuf.Any.newBuilder().build()) - .build()); - case FILEFORMAT_NOT_SET -> builder; - }; + switch (fileFormatCase) { + case PARQUET: + return builder.fileFormat(FileFormat.ParquetReadOptions.builder().build()); + case ARROW: + return builder.fileFormat(FileFormat.ArrowReadOptions.builder().build()); + case ORC: + return builder.fileFormat(FileFormat.OrcReadOptions.builder().build()); + case DWRF: + return builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); + case TEXT: + return builder.fileFormat( + FileFormat.DelimiterSeparatedTextReadOptions.builder() + .fieldDelimiter("|") + .maxLineSize(1000) + .quote("\"") + .headerLinesToSkip(1) + .escape("\\") + .build()); + case EXTENSION: + return builder.fileFormat( + FileFormat.Extension.builder() + .extension(com.google.protobuf.Any.newBuilder().build()) + .build()); + case FILEFORMAT_NOT_SET: + return builder; + default: + throw new IllegalArgumentException("Unknown file format case: " + fileFormatCase); + } } @Test diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java index 99a2e2042..923ab1ecc 100644 --- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java @@ -15,11 +15,11 @@ public class ReadRelRoundtripTest extends TestBase { @Test void namedScan() { - var tableName = Stream.of("a_table").collect(Collectors.toList()); - var columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); + List tableName = Stream.of("a_table").collect(Collectors.toList()); + List columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); List columnTypes = Stream.of(R.I64, R.I64).collect(Collectors.toList()); - var namedScan = b.namedScan(tableName, columnNames, columnTypes); + NamedScan namedScan = b.namedScan(tableName, columnNames, columnTypes); namedScan = NamedScan.builder() .from(namedScan) @@ -33,13 +33,13 @@ void namedScan() { @Test void emptyScan() { - var emptyScan = b.emptyScan(); + io.substrait.relation.EmptyScan emptyScan = b.emptyScan(); verifyRoundTrip(emptyScan); } @Test void virtualTable() { - var virtTable = + io.substrait.relation.ImmutableVirtualTableScan virtTable = VirtualTableScan.builder() .initialSchema( NamedStruct.of( diff --git a/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java b/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java index eb990ca42..9bd2e734f 100644 --- a/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java +++ b/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java @@ -54,7 +54,7 @@ public void roundtrip(boolean n) { * @param type */ private void t(Type type) { - var converted = type.accept(typeProtoConverter); + io.substrait.proto.Type converted = type.accept(typeProtoConverter); assertEquals(type, protoTypeConverter.from(converted)); } diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index 96135999f..1dc8f755b 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -102,11 +102,12 @@ public Integer call() throws Exception { } private void printMessage(Message message) throws IOException { - switch (outputFormat) { - case PROTOJSON -> System.out.println( - JsonFormat.printer().includingDefaultValueFields().print(message)); - case PROTOTEXT -> TextFormat.printer().print(message, System.out); - case BINARY -> message.writeTo(System.out); + if (outputFormat == OutputFormat.PROTOJSON) { + System.out.println(JsonFormat.printer().includingDefaultValueFields().print(message)); + } else if (outputFormat == OutputFormat.PROTOTEXT) { + TextFormat.printer().print(message, System.out); + } else if (outputFormat == OutputFormat.BINARY) { + message.writeTo(System.out); } } diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 01103e023..64f9891a4 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -93,7 +93,7 @@ jreleaser { } java { - toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } + toolchain { languageVersion = JavaLanguageVersion.of(11) } withJavadocJar() withSourcesJar() } @@ -149,8 +149,6 @@ dependencies { "calcite-core brings in commons-lang:commons-lang:2.4 which has a security vulnerability" ) } - annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") testImplementation("com.google.protobuf:protobuf-java:${PROTOBUF_VERSION}") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 6a8d84d88..6cba80781 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -38,7 +38,8 @@ public class AggregateFunctions { * conversion was needed, empty otherwise. */ public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { - if (aggFunction instanceof SqlMinMaxAggFunction fun) { + if (aggFunction instanceof SqlMinMaxAggFunction) { + SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction; return Optional.of( fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX); } else if (aggFunction instanceof SqlAvgAggFunction) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java index ba18be900..f33eaa4c8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java @@ -91,9 +91,12 @@ private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { } private static boolean isSimpleFieldReference(FunctionArg e) { - return e instanceof FieldReference fr - && fr.segments().size() == 1 - && fr.segments().get(0) instanceof FieldReference.StructField; + if (!(e instanceof FieldReference)) { + return false; + } + + List segments = ((FieldReference) e).segments(); + return segments.size() == 1 && segments.get(0) instanceof FieldReference.StructField; } private static int getFieldRefOffset(FieldReference fr) { @@ -194,8 +197,8 @@ private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { } private Expression projectOutNonFieldReference(FunctionArg farg) { - if ((farg instanceof Expression e)) { - return projectOutNonFieldReference(e); + if ((farg instanceof Expression)) { + return projectOutNonFieldReference((Expression) farg); } else { throw new IllegalArgumentException("cannot handle non-expression argument for aggregate"); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java index 15591f06e..81c4e9a49 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java @@ -94,38 +94,38 @@ protected RelNodeVisitor() {} * RelVisitor.reverseAccept(RelNode) due to the lack of ability to extend base classes. */ public final OUTPUT reverseAccept(RelNode node) throws EXCEPTION { - if (node instanceof TableScan scan) { - return this.visit(scan); - } else if (node instanceof TableFunctionScan scan) { - return this.visit(scan); - } else if (node instanceof Values values) { - return this.visit(values); - } else if (node instanceof Filter filter) { - return this.visit(filter); - } else if (node instanceof Calc calc) { - return this.visit(calc); - } else if (node instanceof Project project) { - return this.visit(project); - } else if (node instanceof Join join) { - return this.visit(join); - } else if (node instanceof Correlate correlate) { - return this.visit(correlate); - } else if (node instanceof Union union) { - return this.visit(union); - } else if (node instanceof Intersect intersect) { - return this.visit(intersect); - } else if (node instanceof Minus minus) { - return this.visit(minus); - } else if (node instanceof Match match) { - return this.visit(match); - } else if (node instanceof Sort sort) { - return this.visit(sort); - } else if (node instanceof Exchange exchange) { - return this.visit(exchange); - } else if (node instanceof Aggregate aggregate) { - return this.visit(aggregate); - } else if (node instanceof TableModify modify) { - return this.visit(modify); + if (node instanceof TableScan) { + return this.visit((TableScan) node); + } else if (node instanceof TableFunctionScan) { + return this.visit((TableFunctionScan) node); + } else if (node instanceof Values) { + return this.visit((Values) node); + } else if (node instanceof Filter) { + return this.visit((Filter) node); + } else if (node instanceof Calc) { + return this.visit((Calc) node); + } else if (node instanceof Project) { + return this.visit((Project) node); + } else if (node instanceof Join) { + return this.visit((Join) node); + } else if (node instanceof Correlate) { + return this.visit((Correlate) node); + } else if (node instanceof Union) { + return this.visit((Union) node); + } else if (node instanceof Intersect) { + return this.visit((Intersect) node); + } else if (node instanceof Minus) { + return this.visit((Minus) node); + } else if (node instanceof Match) { + return this.visit((Match) node); + } else if (node instanceof Sort) { + return this.visit((Sort) node); + } else if (node instanceof Exchange) { + return this.visit((Exchange) node); + } else if (node instanceof Aggregate) { + return this.visit((Aggregate) node); + } else if (node instanceof TableModify) { + return this.visit((TableModify) node); } else { return this.visitOther(node); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index bb7b99f9c..5a5cac5dc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import com.github.bsideup.jabel.Desugar; import io.substrait.extendedexpression.ExtendedExpression; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extension.SimpleExtension; @@ -45,12 +44,23 @@ public SqlExpressionToSubstrait( this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); } - @Desugar - private record Result( - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) {} + private static final class Result { + final SqlValidator validator; + final CalciteCatalogReader catalogReader; + final Map nameToTypeMap; + final Map nameToNodeMap; + + Result( + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) { + this.validator = validator; + this.catalogReader = catalogReader; + this.nameToTypeMap = nameToTypeMap; + this.nameToNodeMap = nameToNodeMap; + } + } /** * Converts the given SQL expression to an {@link io.substrait.proto.ExtendedExpression } @@ -78,10 +88,10 @@ public io.substrait.proto.ExtendedExpression convert( var result = registerCreateTablesForExtendedExpression(createStatements); return executeInnerSQLExpressions( sqlExpressions, - result.validator(), - result.catalogReader(), - result.nameToTypeMap(), - result.nameToNodeMap()); + result.validator, + result.catalogReader, + result.nameToTypeMap, + result.nameToNodeMap); } private io.substrait.proto.ExtendedExpression executeInnerSQLExpressions( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 0f9bd08a6..804244c22 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList; import io.substrait.expression.Expression; +import io.substrait.expression.Expression.SortDirection; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; @@ -16,6 +17,7 @@ import io.substrait.relation.Fetch; import io.substrait.relation.Filter; import io.substrait.relation.Join; +import io.substrait.relation.Join.JoinType; import io.substrait.relation.LocalFiles; import io.substrait.relation.NamedScan; import io.substrait.relation.Project; @@ -201,28 +203,47 @@ public RelNode visit(Join join, Context context) throws RuntimeException { join.getCondition() .map(c -> c.accept(expressionRexConverter, context)) .orElse(relBuilder.literal(true)); - var joinType = - switch (join.getJoinType()) { - case INNER -> JoinRelType.INNER; - case LEFT -> JoinRelType.LEFT; - case RIGHT -> JoinRelType.RIGHT; - case OUTER -> JoinRelType.FULL; - case SEMI -> JoinRelType.SEMI; - case ANTI -> JoinRelType.ANTI; - case LEFT_SEMI -> JoinRelType.SEMI; - case LEFT_ANTI -> JoinRelType.ANTI; - case UNKNOWN -> throw new UnsupportedOperationException( - "Unknown join type is not supported"); - default -> throw new UnsupportedOperationException( - "Unsupported join type: " + join.getJoinType().name()); - }; + var joinType = asJoinRelType(join); RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build(); return applyRemap(node, join.getRemap()); } + private JoinRelType asJoinRelType(Join join) { + Join.JoinType type = join.getJoinType(); + + if (type == JoinType.INNER) { + return JoinRelType.INNER; + } + if (type == JoinType.LEFT) { + return JoinRelType.LEFT; + } + if (type == JoinType.RIGHT) { + return JoinRelType.RIGHT; + } + if (type == JoinType.OUTER) { + return JoinRelType.FULL; + } + if (type == JoinType.SEMI) { + return JoinRelType.SEMI; + } + if (type == JoinType.ANTI) { + return JoinRelType.ANTI; + } + if (type == JoinType.LEFT_SEMI) { + return JoinRelType.SEMI; + } + if (type == JoinType.LEFT_ANTI) { + return JoinRelType.ANTI; + } + if (type == JoinType.UNKNOWN) { + throw new UnsupportedOperationException("Unknown join type is not supported"); + } + + throw new UnsupportedOperationException("Unsupported join type: " + join.getJoinType().name()); + } + @Override public RelNode visit(Set set, Context context) throws RuntimeException { - int numInputs = set.getInputs().size(); set.getInputs() .forEach( input -> { @@ -232,22 +253,36 @@ public RelNode visit(Set set, Context context) throws RuntimeException { // correspond to the Calcite relations they are associated with. They are retained for now // to enable users to migrate off of them. // See: https://github.com/substrait-io/substrait-java/issues/303 - var builder = - switch (set.getSetOp()) { - case MINUS_PRIMARY -> relBuilder.minus(false, numInputs); - case MINUS_PRIMARY_ALL, MINUS_MULTISET -> relBuilder.minus(true, numInputs); - case INTERSECTION_PRIMARY, INTERSECTION_MULTISET -> relBuilder.intersect( - false, numInputs); - case INTERSECTION_MULTISET_ALL -> relBuilder.intersect(true, numInputs); - case UNION_DISTINCT -> relBuilder.union(false, numInputs); - case UNION_ALL -> relBuilder.union(true, numInputs); - case UNKNOWN -> throw new UnsupportedOperationException( - "Unknown set operation is not supported"); - }; + var builder = getRelBuilder(set); RelNode node = builder.build(); return applyRemap(node, set.getRemap()); } + private RelBuilder getRelBuilder(Set set) { + int numInputs = set.getInputs().size(); + + switch (set.getSetOp()) { + case MINUS_PRIMARY: + return relBuilder.minus(false, numInputs); + case MINUS_PRIMARY_ALL: + case MINUS_MULTISET: + return relBuilder.minus(true, numInputs); + case INTERSECTION_PRIMARY: + case INTERSECTION_MULTISET: + return relBuilder.intersect(false, numInputs); + case INTERSECTION_MULTISET_ALL: + return relBuilder.intersect(true, numInputs); + case UNION_DISTINCT: + return relBuilder.union(false, numInputs); + case UNION_ALL: + return relBuilder.union(true, numInputs); + case UNKNOWN: + throw new UnsupportedOperationException("Unknown set operation is not supported"); + default: + throw new UnsupportedOperationException("Unsupported set operation: " + set.getSetOp()); + } + } + @Override public RelNode visit(Aggregate aggregate, Context context) throws RuntimeException { if (!PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)) { @@ -366,14 +401,24 @@ private RexNode directedRexNode(Expression.SortField sortField, Context context) var expression = sortField.expr(); var rexNode = expression.accept(expressionRexConverter, context); var sortDirection = sortField.direction(); - return switch (sortDirection) { - case ASC_NULLS_FIRST -> relBuilder.nullsFirst(rexNode); - case ASC_NULLS_LAST -> relBuilder.nullsLast(rexNode); - case DESC_NULLS_FIRST -> relBuilder.nullsFirst(relBuilder.desc(rexNode)); - case DESC_NULLS_LAST -> relBuilder.nullsLast(relBuilder.desc(rexNode)); - case CLUSTERED -> throw new RuntimeException( - String.format("Unexpected Expression.SortDirection: Clustered!")); - }; + + if (sortDirection == Expression.SortDirection.ASC_NULLS_FIRST) { + return relBuilder.nullsFirst(rexNode); + } + if (sortDirection == Expression.SortDirection.ASC_NULLS_LAST) { + return relBuilder.nullsLast(rexNode); + } + if (sortDirection == Expression.SortDirection.DESC_NULLS_FIRST) { + return relBuilder.nullsFirst(relBuilder.desc(rexNode)); + } + if (sortDirection == Expression.SortDirection.DESC_NULLS_LAST) { + return relBuilder.nullsLast(relBuilder.desc(rexNode)); + } + if (sortDirection == Expression.SortDirection.CLUSTERED) { + throw new RuntimeException(String.format("Unexpected Expression.SortDirection: Clustered!")); + } + + throw new IllegalArgumentException("Unsupported sort direction: " + sortDirection); } @Override @@ -398,24 +443,30 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Co var sortDirection = sortField.direction(); RexSlot rexSlot = (RexSlot) rex; int fieldIndex = rexSlot.getIndex(); - var fieldDirection = RelFieldCollation.Direction.ASCENDING; - var nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED; - switch (sortDirection) { - case ASC_NULLS_FIRST -> nullDirection = RelFieldCollation.NullDirection.FIRST; - case ASC_NULLS_LAST -> nullDirection = RelFieldCollation.NullDirection.LAST; - case DESC_NULLS_FIRST -> { - nullDirection = RelFieldCollation.NullDirection.FIRST; - fieldDirection = RelFieldCollation.Direction.DESCENDING; - } - case DESC_NULLS_LAST -> { - nullDirection = RelFieldCollation.NullDirection.LAST; - fieldDirection = RelFieldCollation.Direction.DESCENDING; - } - case CLUSTERED -> fieldDirection = RelFieldCollation.Direction.CLUSTERED; - - default -> throw new RuntimeException( + + final RelFieldCollation.Direction fieldDirection; + final RelFieldCollation.NullDirection nullDirection; + + if (sortDirection == SortDirection.ASC_NULLS_FIRST) { + fieldDirection = RelFieldCollation.Direction.ASCENDING; + nullDirection = RelFieldCollation.NullDirection.FIRST; + } else if (sortDirection == SortDirection.ASC_NULLS_LAST) { + fieldDirection = RelFieldCollation.Direction.ASCENDING; + nullDirection = RelFieldCollation.NullDirection.LAST; + } else if (sortDirection == SortDirection.DESC_NULLS_FIRST) { + nullDirection = RelFieldCollation.NullDirection.FIRST; + fieldDirection = RelFieldCollation.Direction.DESCENDING; + } else if (sortDirection == SortDirection.DESC_NULLS_LAST) { + nullDirection = RelFieldCollation.NullDirection.LAST; + fieldDirection = RelFieldCollation.Direction.DESCENDING; + } else if (sortDirection == SortDirection.CLUSTERED) { + fieldDirection = RelFieldCollation.Direction.CLUSTERED; + nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED; + } else { + throw new RuntimeException( String.format("Unexpected Expression.SortDirection enum: %s !", sortDirection)); } + return new RelFieldCollation(fieldIndex, fieldDirection, nullDirection); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index fdb3f8aec..2e31718f9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -32,9 +32,11 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelFieldCollation.Direction; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; @@ -170,17 +172,7 @@ public Rel visit(org.apache.calcite.rel.core.Join join) { var left = apply(join.getLeft()); var right = apply(join.getRight()); var condition = toExpression(join.getCondition()); - var joinType = - switch (join.getJoinType()) { - case INNER -> Join.JoinType.INNER; - case LEFT -> Join.JoinType.LEFT; - case RIGHT -> Join.JoinType.RIGHT; - case FULL -> Join.JoinType.OUTER; - case SEMI -> Join.JoinType.LEFT_SEMI; - case ANTI -> Join.JoinType.LEFT_ANTI; - default -> throw new UnsupportedOperationException( - "Unsupported join type: " + join.getJoinType()); - }; + var joinType = asJoinType(join); // An INNER JOIN with a join condition of TRUE can be encoded as a Substrait Cross relation if (joinType == Join.JoinType.INNER && TRUE.equals(condition)) { @@ -189,6 +181,26 @@ public Rel visit(org.apache.calcite.rel.core.Join join) { return Join.builder().condition(condition).joinType(joinType).left(left).right(right).build(); } + private Join.JoinType asJoinType(org.apache.calcite.rel.core.Join join) { + JoinRelType type = join.getJoinType(); + + if (type == JoinRelType.INNER) { + return Join.JoinType.INNER; + } else if (type == JoinRelType.LEFT) { + return Join.JoinType.LEFT; + } else if (type == JoinRelType.RIGHT) { + return Join.JoinType.RIGHT; + } else if (type == JoinRelType.FULL) { + return Join.JoinType.OUTER; + } else if (type == JoinRelType.SEMI) { + return Join.JoinType.LEFT_SEMI; + } else if (type == JoinRelType.ANTI) { + return Join.JoinType.LEFT_ANTI; + } + + throw new UnsupportedOperationException("Unsupported join type: " + join.getJoinType()); + } + @Override public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { // left input of correlated-join is similar to the left input of a logical join @@ -197,16 +209,22 @@ public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { // right input of correlated-join is similar to a correlated sub-query apply(correlate.getRight()); - var joinType = - switch (correlate.getJoinType()) { - case INNER -> Join.JoinType.INNER; // corresponds to CROSS APPLY join - case LEFT -> Join.JoinType.LEFT; // corresponds to OUTER APPLY join - default -> throw new IllegalArgumentException( - "Invalid correlated join type: " + correlate.getJoinType()); - }; + var joinType = asJoinType(correlate); return super.visit(correlate); } + private Join.JoinType asJoinType(org.apache.calcite.rel.core.Correlate correlate) { + JoinRelType type = correlate.getJoinType(); + + if (type == JoinRelType.INNER) { + return Join.JoinType.INNER; + } else if (type == JoinRelType.LEFT) { + return Join.JoinType.LEFT; + } + + throw new IllegalArgumentException("Invalid correlated join type: " + correlate.getJoinType()); + } + @Override public Rel visit(org.apache.calcite.rel.core.Union union) { var inputs = apply(union.getInputs()); @@ -316,28 +334,17 @@ public Rel visit(org.apache.calcite.rel.core.Sort sort) { private long asLong(RexNode rex) { var expr = toExpression(rex); - if (expr instanceof Expression.I64Literal i) { - return i.value(); - } else if (expr instanceof Expression.I32Literal i) { - return i.value(); + if (expr instanceof Expression.I64Literal) { + return ((Expression.I64Literal) expr).value(); + } else if (expr instanceof Expression.I32Literal) { + return ((Expression.I32Literal) expr).value(); } throw new UnsupportedOperationException("Unknown type: " + rex); } public static Expression.SortField toSortField( RelFieldCollation collation, Type.Struct inputType) { - Expression.SortDirection direction = - switch (collation.direction) { - case STRICTLY_ASCENDING, ASCENDING -> collation.nullDirection - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.ASC_NULLS_LAST - : Expression.SortDirection.ASC_NULLS_FIRST; - case STRICTLY_DESCENDING, DESCENDING -> collation.nullDirection - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.DESC_NULLS_LAST - : Expression.SortDirection.DESC_NULLS_FIRST; - case CLUSTERED -> Expression.SortDirection.CLUSTERED; - }; + Expression.SortDirection direction = asSortDirection(collation); return Expression.SortField.builder() .expr(FieldReference.newRootStructReference(collation.getFieldIndex(), inputType)) @@ -345,6 +352,24 @@ public static Expression.SortField toSortField( .build(); } + private static Expression.SortDirection asSortDirection(RelFieldCollation collation) { + RelFieldCollation.Direction direction = collation.direction; + + if (direction == Direction.STRICTLY_ASCENDING || direction == Direction.ASCENDING) { + return collation.nullDirection == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.ASC_NULLS_LAST + : Expression.SortDirection.ASC_NULLS_FIRST; + } else if (direction == Direction.STRICTLY_DESCENDING || direction == Direction.DESCENDING) { + return collation.nullDirection == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.DESC_NULLS_LAST + : Expression.SortDirection.DESC_NULLS_FIRST; + } else if (direction == Direction.CLUSTERED) { + return Expression.SortDirection.CLUSTERED; + } + + throw new IllegalArgumentException("Unsupported collation direction: " + direction); + } + @Override public Rel visit(org.apache.calcite.rel.core.Exchange exchange) { return super.visit(exchange); diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index c48943ca1..ef1e919ee 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -65,63 +65,90 @@ private Type toSubstrait(RelDataType type, List names) { } TypeCreator creator = Type.withNullability(type.isNullable()); - return switch (type.getSqlTypeName()) { - case BOOLEAN -> creator.BOOLEAN; - case TINYINT -> creator.I8; - case SMALLINT -> creator.I16; - case INTEGER -> creator.I32; - case BIGINT -> creator.I64; - case REAL -> creator.FP32; - case FLOAT, DOUBLE -> creator.FP64; - case DECIMAL -> { - if (type.getPrecision() > 38) { - throw new UnsupportedOperationException( - "unsupported decimal precision " + type.getPrecision()); + + switch (type.getSqlTypeName()) { + case BOOLEAN: + return creator.BOOLEAN; + case TINYINT: + return creator.I8; + case SMALLINT: + return creator.I16; + case INTEGER: + return creator.I32; + case BIGINT: + return creator.I64; + case REAL: + return creator.FP32; + case FLOAT: + case DOUBLE: + return creator.FP64; + case DECIMAL: + { + if (type.getPrecision() > 38) { + throw new UnsupportedOperationException( + "unsupported decimal precision " + type.getPrecision()); + } + return creator.decimal(type.getPrecision(), type.getScale()); } - yield creator.decimal(type.getPrecision(), type.getScale()); - } - case CHAR -> creator.fixedChar(type.getPrecision()); - case VARCHAR -> { - if (type.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { - yield creator.STRING; + case CHAR: + return creator.fixedChar(type.getPrecision()); + case VARCHAR: + { + if (type.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { + return creator.STRING; + } + return creator.varChar(type.getPrecision()); } - yield creator.varChar(type.getPrecision()); - } - case SYMBOL -> creator.STRING; - case DATE -> creator.DATE; - case TIME -> creator.TIME; - case TIMESTAMP -> creator.precisionTimestamp(type.getPrecision()); - case TIMESTAMP_WITH_LOCAL_TIME_ZONE -> creator.precisionTimestampTZ(type.getPrecision()); - case INTERVAL_YEAR, INTERVAL_YEAR_MONTH, INTERVAL_MONTH -> creator.INTERVAL_YEAR; - case INTERVAL_DAY, - INTERVAL_DAY_HOUR, - INTERVAL_DAY_MINUTE, - INTERVAL_DAY_SECOND, - INTERVAL_HOUR, - INTERVAL_HOUR_MINUTE, - INTERVAL_HOUR_SECOND, - INTERVAL_MINUTE, - INTERVAL_MINUTE_SECOND, - INTERVAL_SECOND -> creator.intervalDay(type.getScale()); - case VARBINARY -> creator.BINARY; - case BINARY -> creator.fixedBinary(type.getPrecision()); - case MAP -> { - MapSqlType map = (MapSqlType) type; - yield creator.map( - toSubstrait(map.getKeyType(), names), toSubstrait(map.getValueType(), names)); - } - case ROW -> { - var children = new ArrayList(); - for (var field : type.getFieldList()) { - names.add(field.getName()); - children.add(toSubstrait(field.getType(), names)); + case SYMBOL: + return creator.STRING; + case DATE: + return creator.DATE; + case TIME: + return creator.TIME; + case TIMESTAMP: + return creator.precisionTimestamp(type.getPrecision()); + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + return creator.precisionTimestampTZ(type.getPrecision()); + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + return creator.INTERVAL_YEAR; + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + return creator.intervalDay(type.getScale()); + case VARBINARY: + return creator.BINARY; + case BINARY: + return creator.fixedBinary(type.getPrecision()); + case MAP: + { + MapSqlType map = (MapSqlType) type; + return creator.map( + toSubstrait(map.getKeyType(), names), toSubstrait(map.getValueType(), names)); } - yield creator.struct(children); - } - case ARRAY -> creator.list(toSubstrait(type.getComponentType(), names)); - default -> throw new UnsupportedOperationException( - String.format("Unable to convert the type " + type.toString())); - }; + case ROW: + { + var children = new ArrayList(); + for (var field : type.getFieldList()) { + names.add(field.getName()); + children.add(toSubstrait(field.getType(), names)); + } + return creator.struct(children); + } + case ARRAY: + return creator.list(toSubstrait(type.getComponentType(), names)); + default: + throw new UnsupportedOperationException( + String.format("Unable to convert the type " + type.toString())); + } } public RelDataType toCalcite( @@ -343,14 +370,17 @@ private boolean n(NullableType type) { } private RelDataType t(boolean nullable, SqlTypeName typeName, Integer... props) { - final RelDataType baseType = - switch (props.length) { - case 0 -> typeFactory.createSqlType(typeName); - case 1 -> typeFactory.createSqlType(typeName, props[0]); - case 2 -> typeFactory.createSqlType(typeName, props[0], props[1]); - default -> throw new IllegalArgumentException( - "Unexpected properties length: " + Arrays.toString(props)); - }; + final RelDataType baseType; + if (props.length == 0) { + baseType = typeFactory.createSqlType(typeName); + } else if (props.length == 1) { + baseType = typeFactory.createSqlType(typeName, props[0]); + } else if (props.length == 2) { + baseType = typeFactory.createSqlType(typeName, props[0], props[1]); + } else { + throw new IllegalArgumentException( + "Unexpected properties length: " + Arrays.toString(props)); + } return typeFactory.createTypeWithNullability(baseType, nullable); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 29039be8d..4009df779 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -108,7 +108,7 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) { return signatures.get(lookupFunction); } - static class WrappedAggregateCall implements GenericCall { + static class WrappedAggregateCall implements FunctionConverter.GenericCall { private final AggregateCall call; private final RelNode input; private final RexBuilder rexBuilder; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 36fc68b37..810754e54 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -66,8 +66,11 @@ public class CallConverters { // For now, we only support handling of SqlKind.REINTEPRETET for the case of stored // user-defined literals - if (operand instanceof Expression.FixedBinaryLiteral literal - && type instanceof Type.UserDefined t) { + if (operand instanceof Expression.FixedBinaryLiteral + && type instanceof Type.UserDefined) { + Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; + Type.UserDefined t = (Type.UserDefined) type; + return Expression.UserDefinedLiteral.builder() .uri(t.uri()) .name(t.name()) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index d962dc63a..cec68fcb7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -117,8 +117,8 @@ private static Optional findEnumArg( return Optional.empty(); } Argument arg = args.get(enumAnchor.argIdx); - if (arg instanceof SimpleExtension.EnumArgument ea) { - return Optional.of(ea); + if (arg instanceof SimpleExtension.EnumArgument) { + return Optional.of((SimpleExtension.EnumArgument) arg); } else { return Optional.empty(); } @@ -127,19 +127,21 @@ private static Optional findEnumArg( static Optional fromRex( SimpleExtension.Function function, RexLiteral literal, int argIdx) { - return switch (literal.getType().getSqlTypeName()) { - case SYMBOL -> { - Object v = literal.getValue(); - if (!literal.isNull() && (v instanceof Enum)) { - Enum value = (Enum) v; - ArgAnchor enumAnchor = argAnchor(function, argIdx); - yield findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name())); - } else { - yield Optional.empty(); + switch (literal.getType().getSqlTypeName()) { + case SYMBOL: + { + Object v = literal.getValue(); + if (!literal.isNull() && (v instanceof Enum)) { + Enum value = (Enum) v; + ArgAnchor enumAnchor = argAnchor(function, argIdx); + return findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name())); + } + + return Optional.empty(); } - } - default -> Optional.empty(); - }; + default: + return Optional.empty(); + } } static boolean canConvert(Enum value) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index bd433ccbe..e6d03b663 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -424,15 +424,7 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) expr.sort().stream() .map( sf -> { - Set direction = - switch (sf.direction()) { - case ASC_NULLS_FIRST -> Set.of(SqlKind.NULLS_FIRST); - case ASC_NULLS_LAST -> Set.of(SqlKind.NULLS_LAST); - case DESC_NULLS_FIRST -> Set.of(SqlKind.DESCENDING, SqlKind.NULLS_FIRST); - case DESC_NULLS_LAST -> Set.of(SqlKind.DESCENDING, SqlKind.NULLS_LAST); - case CLUSTERED -> throw new IllegalArgumentException( - "SORT_DIRECTION_CLUSTERED is not supported"); - }; + Set direction = asSqlKind(sf.direction()); return new RexFieldCollation(sf.expr().accept(this, context), direction); }) .collect(ImmutableList.toImmutableList()); @@ -440,19 +432,8 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) RexWindowBound lowerBound = ToRexWindowBound.lowerBound(rexBuilder, expr.lowerBound()); RexWindowBound upperBound = ToRexWindowBound.upperBound(rexBuilder, expr.upperBound()); - boolean rowMode = - switch (expr.boundsType()) { - case ROWS -> true; - case RANGE -> false; - case UNSPECIFIED -> throw new IllegalArgumentException( - "bounds type on window function must be specified"); - }; - - boolean distinct = - switch (expr.invocation()) { - case UNSPECIFIED, ALL -> false; - case DISTINCT -> true; - }; + boolean rowMode = isRowMode(expr); + boolean distinct = isDistinct(expr); // For queries like: SELECT last_value() IGNORE NULLS OVER ... // Substrait has no mechanism to set this, so by default it is false @@ -478,6 +459,53 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) ignoreNulls); } + private Set asSqlKind(Expression.SortDirection direction) { + switch (direction) { + case ASC_NULLS_FIRST: + return Set.of(SqlKind.NULLS_FIRST); + case ASC_NULLS_LAST: + return Set.of(SqlKind.NULLS_LAST); + case DESC_NULLS_FIRST: + return Set.of(SqlKind.DESCENDING, SqlKind.NULLS_FIRST); + case DESC_NULLS_LAST: + return Set.of(SqlKind.DESCENDING, SqlKind.NULLS_LAST); + case CLUSTERED: + throw new IllegalArgumentException("SORT_DIRECTION_CLUSTERED is not supported"); + default: + throw new IllegalArgumentException("Unsupported sort direction: " + direction); + } + } + + private boolean isRowMode(Expression.WindowFunctionInvocation expr) { + Expression.WindowBoundsType boundsType = expr.boundsType(); + + switch (boundsType) { + case ROWS: + return true; + case RANGE: + return false; + case UNSPECIFIED: + throw new IllegalArgumentException("bounds type on window function must be specified"); + default: + throw new IllegalArgumentException( + "Unsupported window function bounds type: " + boundsType); + } + } + + private boolean isDistinct(Expression.WindowFunctionInvocation expr) { + Expression.AggregationInvocation invocation = expr.invocation(); + + switch (invocation) { + case UNSPECIFIED: + case ALL: + return false; + case DISTINCT: + return true; + default: + throw new IllegalArgumentException("Unsupported window function invocation: " + invocation); + } + } + @Override public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException { List needles = @@ -534,12 +562,12 @@ public RexWindowBound visit(WindowBound.Unbounded unbounded) { private String convert(FunctionArg a) { String v; - if (a instanceof EnumArg ea) { - v = ea.value().toString(); - } else if (a instanceof Expression e) { - v = e.getType().accept(new StringTypeVisitor()); - } else if (a instanceof Type t) { - v = t.accept(new StringTypeVisitor()); + if (a instanceof EnumArg) { + v = ((EnumArg) a).value().toString(); + } else if (a instanceof Expression) { + v = ((Expression) a).getType().accept(new StringTypeVisitor()); + } else if (a instanceof Type) { + v = ((Type) a).accept(new StringTypeVisitor()); } else { throw new IllegalStateException("Unexpected value: " + a); } @@ -561,7 +589,8 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti var segment = expr.segments().get(0); RexInputRef rexInputRef; - if (segment instanceof FieldReference.StructField f) { + if (segment instanceof FieldReference.StructField) { + FieldReference.StructField f = (FieldReference.StructField) segment; rexInputRef = new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); } else { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java index ab6233f54..1cb057527 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java @@ -92,14 +92,14 @@ public Optional convert( } private Optional toInt(Expression.Literal l) { - if (l instanceof Expression.I8Literal i8) { - return Optional.of(i8.value()); - } else if (l instanceof Expression.I16Literal i16) { - return Optional.of(i16.value()); - } else if (l instanceof Expression.I32Literal i32) { - return Optional.of(i32.value()); - } else if (l instanceof Expression.I64Literal i64) { - return Optional.of((int) i64.value()); + if (l instanceof Expression.I8Literal) { + return Optional.of(((Expression.I8Literal) l).value()); + } else if (l instanceof Expression.I16Literal) { + return Optional.of(((Expression.I16Literal) l).value()); + } else if (l instanceof Expression.I32Literal) { + return Optional.of(((Expression.I32Literal) l).value()); + } else if (l instanceof Expression.I64Literal) { + return Optional.of((int) ((Expression.I64Literal) l).value()); } logger.atWarn().log("Literal expected to be int type but was not. {}.", l); return Optional.empty(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 399c63b3e..b2ad4c5dc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -203,7 +203,7 @@ && inputTypesMatchDefinedArguments(inputTypes, args)) { * @param args expected arguments as defined in a {@link SimpleExtension.Function} * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise */ - private static boolean inputTypesMatchDefinedArguments( + private boolean inputTypesMatchDefinedArguments( List inputTypes, List args) { Map> wildcardToType = new HashMap<>(); @@ -242,9 +242,8 @@ private static boolean inputTypesMatchDefinedArguments( *

If this exists, the function finder will attempt to find a least-restrictive match using * these. */ - private static - Optional> getSingularInputType(List functions) { - List matchers = new ArrayList<>(); + private Optional> getSingularInputType(List functions) { + List> matchers = new ArrayList<>(); for (var f : functions) { ParameterizedType firstType = null; @@ -274,15 +273,17 @@ Optional> getSingularInputType(List functions) { } } - return switch (matchers.size()) { - case 0 -> Optional.empty(); - case 1 -> Optional.of(matchers.get(0)); - default -> Optional.of(chained(matchers)); - }; + switch (matchers.size()) { + case 0: + return Optional.empty(); + case 1: + return Optional.of(matchers.get(0)); + default: + return Optional.of(chained(matchers)); + } } - public static SingularArgumentMatcher singular( - F function, ParameterizedType type) { + private SingularArgumentMatcher singular(F function, ParameterizedType type) { return (inputType, outputType) -> { var check = isMatch(inputType, type); if (check) { @@ -292,7 +293,7 @@ public static SingularArgumentMatcher si }; } - public static SingularArgumentMatcher chained(List matchers) { + private SingularArgumentMatcher chained(List> matchers) { return (inputType, outputType) -> { for (var s : matchers) { var outcome = s.tryMatch(inputType, outputType); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 47f63908d..b5aee4168 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -89,112 +89,139 @@ public Expression.Literal convert(RexLiteral literal) { return typedNull(type); } - return switch (literal.getType().getSqlTypeName()) { - case TINYINT -> i8(n, i(literal).intValue()); - case SMALLINT -> i16(n, i(literal).intValue()); - case INTEGER -> i32(n, i(literal).intValue()); - case BIGINT -> i64(n, i(literal).longValue()); - case BOOLEAN -> bool(n, literal.getValueAs(Boolean.class)); - case CHAR -> { - var val = literal.getValue(); - if (val instanceof NlsString nls) { - yield fixedChar(n, nls.getValue()); + switch (literal.getType().getSqlTypeName()) { + case TINYINT: + return i8(n, i(literal).intValue()); + case SMALLINT: + return i16(n, i(literal).intValue()); + case INTEGER: + return i32(n, i(literal).intValue()); + case BIGINT: + return i64(n, i(literal).longValue()); + case BOOLEAN: + return bool(n, literal.getValueAs(Boolean.class)); + case CHAR: + { + var val = literal.getValue(); + if (val instanceof NlsString) { + var nls = (NlsString) val; + return fixedChar(n, nls.getValue()); + } + throw new UnsupportedOperationException("Unable to handle char type: " + val); } - throw new UnsupportedOperationException("Unable to handle char type: " + val); - } - case FLOAT, DOUBLE -> fp64(n, literal.getValueAs(Double.class)); - case REAL -> fp32(n, literal.getValueAs(Float.class)); - - case DECIMAL -> { - BigDecimal bd = bd(literal); - yield decimal(n, bd, literal.getType().getPrecision(), literal.getType().getScale()); - } - case VARCHAR -> { - if (literal.getType().getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { - yield string(n, s(literal)); + case FLOAT: + case DOUBLE: + return fp64(n, literal.getValueAs(Double.class)); + case REAL: + return fp32(n, literal.getValueAs(Float.class)); + + case DECIMAL: + { + BigDecimal bd = bd(literal); + return decimal(n, bd, literal.getType().getPrecision(), literal.getType().getScale()); } + case VARCHAR: + { + if (literal.getType().getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { + return string(n, s(literal)); + } - yield varChar(n, s(literal), literal.getType().getPrecision()); - } - case BINARY -> fixedBinary( - n, - ByteString.copyFrom( - padRightIfNeeded( - literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class), - literal.getType().getPrecision()))); - case VARBINARY -> binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class))); - case SYMBOL -> { - Object value = literal.getValue(); - // case TimeUnitRange tur -> string(n, tur.name()); - if (value instanceof NlsString s) { - yield string(n, s.getValue()); - } else if (value instanceof Enum v) { - Optional r = - EnumConverter.canConvert(v) ? Optional.of(string(n, v.name())) : Optional.empty(); - yield r.orElseThrow( - () -> new UnsupportedOperationException("Unable to handle symbol: " + value)); - } else { - throw new UnsupportedOperationException("Unable to handle symbol: " + value); + return varChar(n, s(literal), literal.getType().getPrecision()); } - } - case DATE -> { - DateString date = literal.getValueAs(DateString.class); - LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); - yield ExpressionCreator.date(n, (int) localDate.toEpochDay()); - } - case TIME -> { - TimeString time = literal.getValueAs(TimeString.class); - LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); - yield time(n, NANOSECONDS.toMicros(localTime.toNanoOfDay())); - } - case TIMESTAMP, TIMESTAMP_WITH_LOCAL_TIME_ZONE -> { - TimestampString timestamp = literal.getValueAs(TimestampString.class); - LocalDateTime ldt = - LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER); - yield timestamp(n, ldt); - } - case INTERVAL_YEAR, INTERVAL_YEAR_MONTH, INTERVAL_MONTH -> { - long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); - var years = intervalLength / 12; - var months = intervalLength - years * 12; - yield intervalYear(n, (int) years, (int) months); - } - case INTERVAL_DAY, - INTERVAL_DAY_HOUR, - INTERVAL_DAY_MINUTE, - INTERVAL_DAY_SECOND, - INTERVAL_HOUR, - INTERVAL_HOUR_MINUTE, - INTERVAL_HOUR_SECOND, - INTERVAL_MINUTE, - INTERVAL_MINUTE_SECOND, - INTERVAL_SECOND -> { - // Calcite represents day/time intervals in milliseconds, despite a default scale of 6. - var totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); - var interval = Duration.ofMillis(totalMillis); - - var days = interval.toDays(); - var seconds = interval.minusDays(days).toSeconds(); - var micros = interval.toMillisPart() * 1000; - - yield intervalDay(n, (int) days, (int) seconds, micros, 6); - } - - case ROW -> { - List literals = (List) literal.getValue(); - yield struct(n, literals.stream().map(this::convert).collect(Collectors.toList())); - } - - case ARRAY -> { - List literals = (List) literal.getValue(); - yield list(n, literals.stream().map(this::convert).collect(Collectors.toList())); - } - - default -> throw new UnsupportedOperationException( - String.format( - "Unable to convert the value of %s of type %s to a literal.", - literal, literal.getType().getSqlTypeName())); - }; + case BINARY: + return fixedBinary( + n, + ByteString.copyFrom( + padRightIfNeeded( + literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class), + literal.getType().getPrecision()))); + case VARBINARY: + return binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class))); + case SYMBOL: + { + Object value = literal.getValue(); + // case TimeUnitRange tur -> string(n, tur.name()); + if (value instanceof NlsString) { + return string(n, ((NlsString) value).getValue()); + } else if (value instanceof Enum) { + Enum v = (Enum) value; + Optional r = + EnumConverter.canConvert(v) ? Optional.of(string(n, v.name())) : Optional.empty(); + return r.orElseThrow( + () -> new UnsupportedOperationException("Unable to handle symbol: " + value)); + } else { + throw new UnsupportedOperationException("Unable to handle symbol: " + value); + } + } + case DATE: + { + DateString date = literal.getValueAs(DateString.class); + LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); + return ExpressionCreator.date(n, (int) localDate.toEpochDay()); + } + case TIME: + { + TimeString time = literal.getValueAs(TimeString.class); + LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); + return time(n, NANOSECONDS.toMicros(localTime.toNanoOfDay())); + } + case TIMESTAMP: + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + { + TimestampString timestamp = literal.getValueAs(TimestampString.class); + LocalDateTime ldt = + LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER); + return timestamp(n, ldt); + } + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + { + long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); + var years = intervalLength / 12; + var months = intervalLength - years * 12; + return intervalYear(n, (int) years, (int) months); + } + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + { + // Calcite represents day/time intervals in milliseconds, despite a default scale of 6. + var totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); + var interval = Duration.ofMillis(totalMillis); + + var days = interval.toDays(); + var seconds = interval.minusDays(days).toSeconds(); + var micros = interval.toMillisPart() * 1000; + + return intervalDay(n, (int) days, (int) seconds, micros, 6); + } + + case ROW: + { + List literals = (List) literal.getValue(); + return struct(n, literals.stream().map(this::convert).collect(Collectors.toList())); + } + + case ARRAY: + { + List literals = (List) literal.getValue(); + return list(n, literals.stream().map(this::convert).collect(Collectors.toList())); + } + + default: + throw new UnsupportedOperationException( + String.format( + "Unable to convert the value of %s of type %s to a literal.", + literal, literal.getType().getSqlTypeName())); + } } public static byte[] padRightIfNeeded( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 2bc7ec534..a02e7fbac 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -129,25 +129,30 @@ public Expression visitRangeRef(RexRangeRef rangeRef) { public Expression visitFieldAccess(RexFieldAccess fieldAccess) { SqlKind kind = fieldAccess.getReferenceExpr().getKind(); switch (kind) { - case CORREL_VARIABLE -> { - int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess); - - return FieldReference.newRootStructOuterReference( - fieldAccess.getField().getIndex(), - typeConverter.toSubstrait(fieldAccess.getType()), - stepsOut); - } - case ITEM, INPUT_REF, FIELD_ACCESS -> { - Expression expression = fieldAccess.getReferenceExpr().accept(this); - if (expression instanceof FieldReference) { - FieldReference nestedReference = (FieldReference) expression; - return nestedReference.dereferenceStruct(fieldAccess.getField().getIndex()); - } else { - return FieldReference.newStructReference(fieldAccess.getField().getIndex(), expression); + case CORREL_VARIABLE: + { + int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess); + + return FieldReference.newRootStructOuterReference( + fieldAccess.getField().getIndex(), + typeConverter.toSubstrait(fieldAccess.getType()), + stepsOut); } - } - default -> throw new UnsupportedOperationException( - String.format("RexFieldAccess for SqlKind %s not supported", kind)); + case ITEM: + case INPUT_REF: + case FIELD_ACCESS: + { + Expression expression = fieldAccess.getReferenceExpr().accept(this); + if (expression instanceof FieldReference) { + FieldReference nestedReference = (FieldReference) expression; + return nestedReference.dereferenceStruct(fieldAccess.getField().getIndex()); + } else { + return FieldReference.newStructReference(fieldAccess.getField().getIndex(), expression); + } + } + default: + throw new UnsupportedOperationException( + String.format("RexFieldAccess for SqlKind %s not supported", kind)); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java index c80fd2995..53879fdc7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java @@ -135,7 +135,7 @@ private Optional> getMappedExpressionArguments( .orElse(Optional.empty()); } - protected static class WrappedScalarCall implements GenericCall { + protected static class WrappedScalarCall implements FunctionConverter.GenericCall { private final RexCall delegate; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java index 9cfbf9db5..460c825b7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelFieldCollation.Direction; import org.apache.calcite.rex.RexFieldCollation; public class SortFieldConverter { @@ -10,23 +11,27 @@ public class SortFieldConverter { public static Expression.SortField toSortField( RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) { var expr = rexFieldCollation.left.accept(rexExpressionConverter); - var rexDirection = rexFieldCollation.getDirection(); - Expression.SortDirection direction = - switch (rexDirection) { - case ASCENDING -> rexFieldCollation.getNullDirection() - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.ASC_NULLS_LAST - : Expression.SortDirection.ASC_NULLS_FIRST; - case DESCENDING -> rexFieldCollation.getNullDirection() - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.DESC_NULLS_LAST - : Expression.SortDirection.DESC_NULLS_FIRST; - default -> throw new IllegalArgumentException( - String.format( - "Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!", - rexDirection)); - }; + Expression.SortDirection direction = asSortDirection(rexFieldCollation); return Expression.SortField.builder().expr(expr).direction(direction).build(); } + + private static Expression.SortDirection asSortDirection(RexFieldCollation collation) { + RelFieldCollation.Direction direction = collation.getDirection(); + + if (direction == Direction.ASCENDING) { + return collation.getNullDirection() == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.ASC_NULLS_LAST + : Expression.SortDirection.ASC_NULLS_FIRST; + } + if (direction == Direction.DESCENDING) { + return collation.getNullDirection() == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.DESC_NULLS_LAST + : Expression.SortDirection.DESC_NULLS_FIRST; + } + + throw new IllegalArgumentException( + String.format( + "Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!", direction)); + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java index 3208905ca..ad773136e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java @@ -16,18 +16,25 @@ public static WindowBound toWindowBound(RexWindowBound rexWindowBound) { if (rexWindowBound.isUnbounded()) { return WindowBound.UNBOUNDED; } else { - if (rexWindowBound.getOffset() instanceof RexLiteral literal - && SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { - BigDecimal offset = (BigDecimal) literal.getValue4(); - if (rexWindowBound.isPreceding()) { - return WindowBound.Preceding.of(offset.longValue()); - } - if (rexWindowBound.isFollowing()) { - return WindowBound.Following.of(offset.longValue()); + var node = rexWindowBound.getOffset(); + + if (node instanceof RexLiteral) { + var literal = (RexLiteral) node; + if (SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { + BigDecimal offset = (BigDecimal) literal.getValue4(); + + if (rexWindowBound.isPreceding()) { + return WindowBound.Preceding.of(offset.longValue()); + } + if (rexWindowBound.isFollowing()) { + return WindowBound.Following.of(offset.longValue()); + } + + throw new IllegalStateException( + "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING"); } - throw new IllegalStateException( - "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING"); } + throw new IllegalArgumentException( String.format( "substrait only supports integer window offsets. Received: %s", diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java index a1e9bff33..1d988495c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java @@ -99,7 +99,7 @@ public Optional convert( return m.attemptMatch(wrapped, topLevelConverter); } - static class WrappedWindowRelCall implements GenericCall { + static class WrappedWindowRelCall implements FunctionConverter.GenericCall { private final Window.RexWinAggCall winAggCall; private final RexWindowBound lowerBound; private final RexWindowBound upperBound; diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java index 83967195a..3751c0880 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java @@ -67,10 +67,12 @@ public static List processCreateStatements(String createStatemen SqlNodeList sqlNode = parser.parseStmtList(); for (SqlNode parsed : sqlNode) { - if (!(parsed instanceof SqlCreateTable create)) { + if (!(parsed instanceof SqlCreateTable)) { throw fail("Not a valid CREATE TABLE statement."); } + SqlCreateTable create = (SqlCreateTable) parsed; + if (create.name.names.size() > 1) { throw fail("Only simple table names are allowed.", create.name.getParserPosition()); } @@ -83,7 +85,7 @@ public static List processCreateStatements(String createStatemen List columnTypes = new ArrayList<>(); for (SqlNode node : create.columnList) { - if (!(node instanceof SqlColumnDeclaration col)) { + if (!(node instanceof SqlColumnDeclaration)) { if (node instanceof SqlKeyConstraint) { // key constraints declarations, like primary key declaration, are valid and should not // result in parse exceptions. Ignore the constraint declaration. @@ -93,6 +95,8 @@ public static List processCreateStatements(String createStatemen throw fail("Unexpected column list construction.", node.getParserPosition()); } + SqlColumnDeclaration col = (SqlColumnDeclaration) node; + if (col.name.names.size() != 1) { throw fail("Expected simple column names.", col.name.getParserPosition()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index 62357aa3d..32d4b2553 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -41,15 +41,20 @@ public class AggregationFunctionsTest extends PlanTestBase { // Create the given function call on the given field of the input private Aggregate.Measure functionPicker(Rel input, int field, String fname) { - return switch (fname) { - case "min" -> b.min(input, field); - case "max" -> b.max(input, field); - case "sum" -> b.sum(input, field); - case "sum0" -> b.sum0(input, field); - case "avg" -> b.avg(input, field); - default -> throw new RuntimeException( - String.format("no function is associated with %s", fname)); - }; + switch (fname) { + case "min": + return b.min(input, field); + case "max": + return b.max(input, field); + case "sum": + return b.sum(input, field); + case "sum0": + return b.sum0(input, field); + case "avg": + return b.avg(input, field); + default: + throw new RuntimeException(String.format("no function is associated with %s", fname)); + } } // Create one function call per numeric type column diff --git a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java index e8036a248..5c3f963bf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java @@ -37,10 +37,9 @@ private static Map buildOuterFieldRefMap(RelRoot root) public void lateralJoinQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales CROSS JOIN LATERAL - (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"""; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales CROSS JOIN LATERAL\n" + + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; /* the calcite plan for the above query is: LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3]) @@ -69,11 +68,9 @@ public void lateralJoinQuery() throws SqlParseException { public void outerApplyQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales OUTER APPLY - (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk) - """; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales OUTER APPLY\n" + + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; RelRoot root = getCalcitePlan(sql); @@ -92,15 +89,14 @@ public void outerApplyQuery() throws SqlParseException { public void nestedApplyJoinQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales CROSS APPLY - ( SELECT i_item_sk - FROM item CROSS APPLY - ( SELECT p_promo_sk - FROM promotion - WHERE p_item_sk = i_item_sk AND p_item_sk = ss_item_sk ) - WHERE item.i_item_sk = store_sales.ss_item_sk )"""; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales CROSS APPLY\n" + + " ( SELECT i_item_sk\n" + + " FROM item CROSS APPLY\n" + + " ( SELECT p_promo_sk\n" + + " FROM promotion\n" + + " WHERE p_item_sk = i_item_sk AND p_item_sk = ss_item_sk )\n" + + " WHERE item.i_item_sk = store_sales.ss_item_sk )"; /* the calcite plan for the above query is: LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3]) @@ -135,10 +131,9 @@ public void nestedApplyJoinQuery() throws SqlParseException { public void crossApplyQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales CROSS APPLY - (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"""; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales CROSS APPLY\n" + + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; FeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); SqlToSubstrait s = new SqlToSubstrait(featureBoard); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java index fc3c4559b..775e1bfe5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java @@ -14,12 +14,16 @@ public abstract class CalciteObjs { final RexBuilder rex = new RexBuilder(type); RelDataType t(SqlTypeName typeName, int... vals) { - return switch (vals.length) { - case 0 -> type.createSqlType(typeName); - case 1 -> type.createSqlType(typeName, vals[0]); - case 2 -> type.createSqlType(typeName, vals[0], vals[1]); - default -> throw new IllegalArgumentException(); - }; + switch (vals.length) { + case 0: + return type.createSqlType(typeName); + case 1: + return type.createSqlType(typeName, vals[0]); + case 2: + return type.createSqlType(typeName, vals[0], vals[1]); + default: + throw new IllegalArgumentException(); + } } RelDataType tN(SqlTypeName typeName, int... vals) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index 4972fce50..3f7407530 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -65,11 +65,9 @@ void handleInputReferenceSort() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - Collation: [0] - LogicalSort(sort0=[$0], dir0=[ASC]) - LogicalTableScan(table=[[example]]) - """; + "Collation: [0]\n" + + "LogicalSort(sort0=[$0], dir0=[ASC])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -95,13 +93,11 @@ void handleCastExpressionSort() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - LogicalProject(a0=[$0]) - Collation: [1] - LogicalSort(sort0=[$1], dir0=[ASC]) - LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[$0])\n" + + " Collation: [1]\n" + + " LogicalSort(sort0=[$1], dir0=[ASC])\n" + + " LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -127,13 +123,11 @@ void handleCastProjectAndSortWithSortDirection() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - LogicalProject(a0=[CAST($0):INTEGER NOT NULL]) - Collation: [1 DESC-nulls-last] - LogicalSort(sort0=[$1], dir0=[DESC-nulls-last]) - LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[CAST($0):INTEGER NOT NULL])\n" + + " Collation: [1 DESC-nulls-last]\n" + + " LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])\n" + + " LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -159,13 +153,11 @@ void handleCastSortToOriginalType() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - LogicalProject(a0=[$0]) - Collation: [1 DESC-nulls-last] - LogicalSort(sort0=[$1], dir0=[DESC-nulls-last]) - LogicalProject(a=[$0], a0=[$0]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[$0])\n" + + " Collation: [1 DESC-nulls-last]\n" + + " LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])\n" + + " LogicalProject(a=[$0], a0=[$0])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -194,13 +186,11 @@ void handleComplex2ExpressionSort() { b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32)))); String expected = - """ - LogicalProject(a0=[$0], b0=[$1]) - Collation: [2 DESC, 3] - LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC]) - LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[$0], b0=[$1])\n" + + " Collation: [2 DESC, 3]\n" + + " LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC])\n" + + " LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index bb4dcd746..e4e1c4920 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -21,9 +21,7 @@ void preserveNamesFromSql() throws Exception { SqlToSubstrait s = new SqlToSubstrait(); var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - String query = """ - SELECT "a", "B" FROM foo GROUP BY a, b - """; + String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; List expectedNames = List.of("a", "B"); List calciteRelRoots = s.sqlToRelNode(query, catalogReader); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java index 0b437d789..ea7603a1f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java @@ -85,24 +85,17 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a" - FROM - "nested"."my_table"; - """; + "SELECT\n" + " \"nested\".\"my_table\".\"a\"\n" + "FROM\n" + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -121,29 +114,25 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"."b" - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\".\"b\"\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - struct_field { - field: 0 # b - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # b\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -162,34 +151,30 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"."b"."c" - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\".\"b\".\"c\"\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - struct_field { - field: 0 # b - child: { - struct_field { - field: 0 # c - } - } - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # b\n" + + " child: {\n" + + " struct_field {\n" + + " field: 0 # c\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -207,29 +192,25 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"[1] - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\"[1]\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - list_element { - offset: 1 - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " list_element {\n" + + " offset: 1\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -251,39 +232,35 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"[1][2][3] - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\"[1][2][3]\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - list_element { - offset: 1 - child { - list_element { - offset: 2 - child { - list_element { - offset: 3 - } - } - } - } - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " list_element {\n" + + " offset: 1\n" + + " child {\n" + + " list_element {\n" + + " offset: 2\n" + + " child {\n" + + " list_element {\n" + + " offset: 3\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -308,51 +285,47 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table".a.b[2].c['my_map_key'].x - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".a.b[2].c['my_map_key'].x\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 0 # .a - child { - struct_field { - field: 0 # .b - child { - list_element { - offset: 2 - child { - struct_field { - field: 0 # .c - child { - map_key { - map_key { - string: "my_map_key" # ['my_map_key'] - } - child { - struct_field { - field: 0 # .x - } - } - } - } - } - } - } - } - } - } - } - } - root_reference {} - } - """; + " selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 0 # .a\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # .b\n" + + " child {\n" + + " list_element {\n" + + " offset: 2\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # .c\n" + + " child {\n" + + " map_key {\n" + + " map_key {\n" + + " string: \"my_map_key\" # ['my_map_key']\n" + + " }\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # .x\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference {}\n" + + "}\n"; test(table, query, expectedExpressionText); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index 48d13612f..0dca1f7d1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -76,34 +76,28 @@ public Optional visit(Cross cross, EmptyVisitationContext context) }; var featureBoard = ImmutableFeatureBoard.builder().build(); - Plan plan1 = - assertProtoPlanRoundrip( - """ - select - c.c_custKey, - o.o_custkey - from - "customer" c cross join - "orders" o - """, - new SqlToSubstrait(featureBoard)); + String query1 = + "select\n" + + " c.c_custKey,\n" + + " o.o_custkey\n" + + "from\n" + + " \"customer\" c cross join\n" + + " \"orders\" o"; + Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait(featureBoard)); plan1 .getRoots() .forEach( t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE)); assertEquals(1, counter[0]); - Plan plan2 = - assertProtoPlanRoundrip( - """ - select - c.c_custKey, - o.o_custkey - from - "customer" c, - "orders" o - """, - new SqlToSubstrait(featureBoard)); + String query2 = + "select\n" + + " c.c_custKey,\n" + + " o.o_custkey\n" + + "from\n" + + " \"customer\" c,\n" + + " \"orders\" o"; + Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait(featureBoard)); plan2 .getRoots() .forEach( diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java b/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java index deac90e87..3e57905e2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java @@ -18,17 +18,7 @@ private SetUtils() {} * @return a sql query */ public static String getSetQuery(Set.SetOp op, boolean multi) { - String opString = - switch (op) { - case MINUS_PRIMARY -> "EXCEPT"; - case MINUS_PRIMARY_ALL -> "EXCEPT ALL"; - case INTERSECTION_MULTISET -> "INTERSECT"; - case INTERSECTION_MULTISET_ALL -> "INTERSECT ALL"; - case UNION_DISTINCT -> "UNION"; - case UNION_ALL -> "UNION ALL"; - default -> throw new UnsupportedOperationException( - "Unknown set operation is not supported"); - }; + String opString = asString(op); StringBuilder query = new StringBuilder(); query.append( @@ -49,6 +39,25 @@ public static String getSetQuery(Set.SetOp op, boolean multi) { } } + private static String asString(Set.SetOp op) { + switch (op) { + case MINUS_PRIMARY: + return "EXCEPT"; + case MINUS_PRIMARY_ALL: + return "EXCEPT ALL"; + case INTERSECTION_MULTISET: + return "INTERSECT"; + case INTERSECTION_MULTISET_ALL: + return "INTERSECT ALL"; + case UNION_DISTINCT: + return "UNION"; + case UNION_ALL: + return "UNION ALL"; + default: + throw new UnsupportedOperationException("Unknown set operation is not supported"); + } + } + // Generate all SetOp types excluding: // * MINUS_MULTISET, INTERSECTION_PRIMARY: do not map to Calcite relations // * UNKNOWN: invalid diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts index 0acec7036..391470c99 100644 --- a/spark/build.gradle.kts +++ b/spark/build.gradle.kts @@ -96,7 +96,7 @@ configurations.all { } java { - toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } + toolchain { languageVersion = JavaLanguageVersion.of(17) } withJavadocJar() withSourcesJar() } @@ -150,5 +150,6 @@ tasks { test { dependsOn(":core:shadowJar") useJUnitPlatform { includeEngines("scalatest") } + jvmArgs("--add-exports=java.base/sun.nio.ch=ALL-UNNAMED") } } From a435b669a7b689c6d3b6908478bd4718970aa25e Mon Sep 17 00:00:00 2001 From: "Mark S. Lewis" Date: Thu, 17 Jul 2025 10:57:33 +0100 Subject: [PATCH 2/2] chore: address review comments Signed-off-by: Mark S. Lewis --- .../main/java/io/substrait/relation/Join.java | 18 ++++++++---------- .../substrait/relation/files/FileFormat.java | 12 ++++++------ .../relation/ProtoRelConverterTest.java | 10 +++++----- spark/build.gradle.kts | 5 +---- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index 205e51572..adb9cd535 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -76,12 +76,11 @@ private Stream getLeftTypes() { return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); case RIGHT_SEMI: case RIGHT_ANTI: - return Stream.of(); // these are right joins which ignore left side columns + // these are right joins which ignore left side columns + return Stream.of(); case RIGHT_MARK: - return Stream.of( - TypeCreator.REQUIRED - .BOOLEAN); // right mark join keeps all fields from right and adds a boolean mark - // field + // right mark join keeps all fields from right and adds a boolean mark field + return Stream.of(TypeCreator.REQUIRED.BOOLEAN); default: return getLeft().getRecordType().fields().stream(); } @@ -97,12 +96,11 @@ private Stream getRightTypes() { case ANTI: case LEFT_SEMI: case LEFT_ANTI: - return Stream.of(); // these are left joins which ignore right side columns + // these are left joins which ignore right side columns + return Stream.of(); case LEFT_MARK: - return Stream.of( - TypeCreator.REQUIRED - .BOOLEAN); // left mark join keeps all fields from left and adds a boolean mark - // field + // left mark join keeps all fields from left and adds a boolean mark field + return Stream.of(TypeCreator.REQUIRED.BOOLEAN); default: return getRight().getRecordType().fields().stream(); } diff --git a/core/src/main/java/io/substrait/relation/files/FileFormat.java b/core/src/main/java/io/substrait/relation/files/FileFormat.java index ed9482d3b..415ff714e 100644 --- a/core/src/main/java/io/substrait/relation/files/FileFormat.java +++ b/core/src/main/java/io/substrait/relation/files/FileFormat.java @@ -7,35 +7,35 @@ public interface FileFormat { @Value.Immutable - abstract static class ParquetReadOptions implements FileFormat { + abstract class ParquetReadOptions implements FileFormat { public static ImmutableFileFormat.ParquetReadOptions.Builder builder() { return ImmutableFileFormat.ParquetReadOptions.builder(); } } @Value.Immutable - abstract static class ArrowReadOptions implements FileFormat { + abstract class ArrowReadOptions implements FileFormat { public static ImmutableFileFormat.ArrowReadOptions.Builder builder() { return ImmutableFileFormat.ArrowReadOptions.builder(); } } @Value.Immutable - abstract static class OrcReadOptions implements FileFormat { + abstract class OrcReadOptions implements FileFormat { public static ImmutableFileFormat.OrcReadOptions.Builder builder() { return ImmutableFileFormat.OrcReadOptions.builder(); } } @Value.Immutable - abstract static class DwrfReadOptions implements FileFormat { + abstract class DwrfReadOptions implements FileFormat { public static ImmutableFileFormat.DwrfReadOptions.Builder builder() { return ImmutableFileFormat.DwrfReadOptions.builder(); } } @Value.Immutable - abstract static class DelimiterSeparatedTextReadOptions implements FileFormat { + abstract class DelimiterSeparatedTextReadOptions implements FileFormat { public abstract String getFieldDelimiter(); public abstract long getMaxLineSize(); @@ -54,7 +54,7 @@ public static ImmutableFileFormat.DelimiterSeparatedTextReadOptions.Builder buil } @Value.Immutable - abstract static class Extension implements FileFormat { + abstract class Extension implements FileFormat { public abstract com.google.protobuf.Any getExtension(); public static ImmutableFileFormat.Extension.Builder builder() { diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java index b8dd03465..eeca3b134 100644 --- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java +++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java @@ -25,8 +25,8 @@ public class ProtoRelConverterTest extends TestBase { @Nested class DefaultAdvancedExtensionTests { - final StringHolder ENHANCED = new StringHolder("ENHANCED"); - final StringHolder OPTIMIZED = new StringHolder("OPTIMIZED"); + final StringHolder enhanced = new StringHolder("ENHANCED"); + final StringHolder optimized = new StringHolder("OPTIMIZED"); Rel relWithExtension(AdvancedExtension advancedExtension) { return NamedScan.builder() @@ -38,12 +38,12 @@ Rel relWithExtension(AdvancedExtension advancedExtension) { Rel emptyAdvancedExtension = relWithExtension(AdvancedExtension.builder().build()); Rel advancedExtensionWithOptimization = - relWithExtension(AdvancedExtension.builder().addOptimizations(OPTIMIZED).build()); + relWithExtension(AdvancedExtension.builder().addOptimizations(optimized).build()); Rel advancedExtensionWithEnhancement = - relWithExtension(AdvancedExtension.builder().enhancement(ENHANCED).build()); + relWithExtension(AdvancedExtension.builder().enhancement(enhanced).build()); Rel advancedExtensionWithEnhancementAndOptimization = relWithExtension( - AdvancedExtension.builder().enhancement(ENHANCED).addOptimizations(OPTIMIZED).build()); + AdvancedExtension.builder().enhancement(enhanced).addOptimizations(optimized).build()); @Test void emptyAdvancedExtension() { diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts index 391470c99..6af7a3397 100644 --- a/spark/build.gradle.kts +++ b/spark/build.gradle.kts @@ -101,10 +101,7 @@ java { withSourcesJar() } -tasks.withType() { - targetCompatibility = "" - scalaCompileOptions.additionalParameters = listOf("-release:17") -} +tasks.withType() { scalaCompileOptions.additionalParameters = listOf("-release:17") } var SLF4J_VERSION = properties.get("slf4j.version") var SPARKBUNDLE_VERSION = properties.get("sparkbundle.version")