diff --git a/build.gradle.kts b/build.gradle.kts index 667ffb892..ec1b80873 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -26,8 +26,6 @@ dependencies { implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}") - annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } val submodulesUpdate by @@ -41,20 +39,10 @@ allprojects { repositories { mavenCentral() } tasks.configureEach { - val javaToolchains = project.extensions.getByType() useJUnitPlatform() - javaLauncher.set(javaToolchains.launcherFor { languageVersion.set(JavaLanguageVersion.of(11)) }) testLogging { exceptionFormat = TestExceptionFormat.FULL } } - tasks.withType { - sourceCompatibility = "17" - if (project.name != "core") { - options.release.set(11) - } else { - options.release.set(8) - } - dependsOn(submodulesUpdate) - } + tasks.withType { dependsOn(submodulesUpdate) } group = "io.substrait" version = "${version}" diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 97b30a3f1..150bf8b2b 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -128,8 +128,6 @@ dependencies { implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}") compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}") - annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } configurations[JavaPlugin.API_CONFIGURATION_NAME].let { apiConfiguration -> @@ -233,12 +231,12 @@ tasks { jar { manifest { from("build/generated/sources/manifest/META-INF/MANIFEST.MF") } } } +// Set the release instead of using a Java 8 toolchain since ANTLR requires Java 11+ to run +tasks.withType().configureEach { options.release = 8 } + java { - toolchain { - languageVersion.set(JavaLanguageVersion.of(17)) - withJavadocJar() - withSourcesJar() - } + withJavadocJar() + withSourcesJar() } configurations { runtimeClasspath { resolutionStrategy.activateDependencyLocking() } } diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index f2985ff4a..9de3a0faf 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -1,6 +1,5 @@ package io.substrait.dsl; -import com.github.bsideup.jabel.Desugar; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.Expression.Cast; @@ -87,8 +86,8 @@ private Aggregate aggregate( Function> measuresFn, Optional remap, Rel input) { - var groupings = groupingsFn.apply(input); - var measures = measuresFn.apply(input); + List groupings = groupingsFn.apply(input); + List measures = measuresFn.apply(input); return Aggregate.builder() .groupings(groupings) .measures(measures) @@ -147,12 +146,27 @@ public Filter filter(Function conditionFn, Rel.Remap remap, Rel private Filter filter( Function conditionFn, Optional remap, Rel input) { - var condition = conditionFn.apply(input); + Expression condition = conditionFn.apply(input); return Filter.builder().input(input).condition(condition).remap(remap).build(); } - @Desugar - public record JoinInput(Rel left, Rel right) {} + public static final class JoinInput { + private final Rel left; + private final Rel right; + + JoinInput(Rel left, Rel right) { + this.left = left; + this.right = right; + } + + public Rel left() { + return left; + } + + public Rel right() { + return right; + } + } public Join innerJoin(Function conditionFn, Rel left, Rel right) { return join(conditionFn, Join.JoinType.INNER, left, right); @@ -183,7 +197,7 @@ private Join join( Optional remap, Rel left, Rel right) { - var condition = conditionFn.apply(new JoinInput(left, right)); + Expression condition = conditionFn.apply(new JoinInput(left, right)); return Join.builder() .left(left) .right(right) @@ -263,7 +277,7 @@ private NestedLoopJoin nestedLoopJoin( Optional remap, Rel left, Rel right) { - var condition = conditionFn.apply(new JoinInput(left, right)); + Expression condition = conditionFn.apply(new JoinInput(left, right)); return NestedLoopJoin.builder() .left(left) .right(right) @@ -291,8 +305,8 @@ private NamedScan namedScan( Iterable columnNames, Iterable types, Optional remap) { - var struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); - var namedStruct = NamedStruct.of(columnNames, struct); + Type.Struct struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); + NamedStruct namedStruct = NamedStruct.of(columnNames, struct); return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } @@ -315,7 +329,7 @@ private Project project( Function> expressionsFn, Optional remap, Rel input) { - var expressions = expressionsFn.apply(input); + Iterable expressions = expressionsFn.apply(input); return Project.builder().input(input).expressions(expressions).remap(remap).build(); } @@ -332,7 +346,7 @@ private Expand expand( Function> fieldsFn, Optional remap, Rel input) { - var fields = fieldsFn.apply(input); + Iterable fields = fieldsFn.apply(input); return Expand.builder().input(input).fields(fields).remap(remap).build(); } @@ -363,7 +377,7 @@ private Sort sort( Function> sortFieldFn, Optional remap, Rel input) { - var condition = sortFieldFn.apply(input); + Iterable condition = sortFieldFn.apply(input); return Sort.builder().input(input).sortFields(condition).remap(remap).build(); } @@ -465,7 +479,7 @@ public Switch switchExpression( public AggregateFunctionInvocation aggregateFn( String namespace, String key, Type outputType, Expression... args) { - var declaration = + SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); return AggregateFunctionInvocation.builder() .arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList())) @@ -477,7 +491,7 @@ public AggregateFunctionInvocation aggregateFn( } public Aggregate.Grouping grouping(Rel input, int... indexes) { - var columns = fieldReferences(input, indexes); + List columns = fieldReferences(input, indexes); return Aggregate.Grouping.builder().addAllExpressions(columns).build(); } @@ -486,7 +500,7 @@ public Aggregate.Grouping grouping(Expression... expressions) { } public Aggregate.Measure count(Rel input, int field) { - var declaration = + SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any")); @@ -563,7 +577,7 @@ public Aggregate.Measure sum0(Expression expr) { private Aggregate.Measure singleArgumentArithmeticAggregate( Expression expr, String functionName, Type outputType) { String typeString = ToTypeString.apply(expr.getType()); - var declaration = + SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, @@ -585,7 +599,7 @@ private Aggregate.Measure singleArgumentArithmeticAggregate( public Expression.ScalarFunctionInvocation negate(Expression expr) { // output type of negate is the same as the input type - var outputType = expr.getType(); + Type outputType = expr.getType(); return scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("negate:%s", ToTypeString.apply(outputType)), @@ -611,12 +625,12 @@ public Expression.ScalarFunctionInvocation divide(Expression left, Expression ri private Expression.ScalarFunctionInvocation arithmeticFunction( String fname, Expression left, Expression right) { - var leftTypeStr = ToTypeString.apply(left.getType()); - var rightTypeStr = ToTypeString.apply(right.getType()); - var key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); + String leftTypeStr = ToTypeString.apply(left.getType()); + String rightTypeStr = ToTypeString.apply(right.getType()); + String key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); - var isOutputNullable = left.getType().nullable() || right.getType().nullable(); - var outputType = left.getType(); + boolean isOutputNullable = left.getType().nullable() || right.getType().nullable(); + Type outputType = left.getType(); outputType = isOutputNullable ? TypeCreator.asNullable(outputType) @@ -633,14 +647,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig public Expression.ScalarFunctionInvocation or(Expression... args) { // If any arg is nullable, the output of or is potentially nullable // For example: false or null = null - var isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); - var outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; + boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); + Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args); } public Expression.ScalarFunctionInvocation scalarFn( String namespace, String key, Type outputType, FunctionArg... args) { - var declaration = + SimpleExtension.ScalarFunctionVariant declaration = extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); return Expression.ScalarFunctionInvocation.builder() .declaration(declaration) @@ -659,7 +673,7 @@ public Expression.WindowFunctionInvocation windowFn( WindowBound lowerBound, WindowBound upperBound, Expression... args) { - var declaration = + SimpleExtension.WindowFunctionVariant declaration = extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key)); return Expression.WindowFunctionInvocation.builder() .declaration(declaration) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 75e003f53..913a80994 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -833,7 +833,7 @@ public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() { public static WindowBoundsType fromProto( io.substrait.proto.Expression.WindowFunction.BoundsType proto) { - for (var v : values()) { + for (WindowBoundsType v : values()) { if (v.proto == proto) { return v; } @@ -984,7 +984,7 @@ public io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp toProto() public static PredicateOp fromProto( io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) { - for (var v : values()) { + for (PredicateOp v : values()) { if (v.proto == proto) { return v; } @@ -1010,7 +1010,7 @@ public io.substrait.proto.AggregateFunction.AggregationInvocation toProto() { } public static AggregationInvocation fromProto(AggregateFunction.AggregationInvocation proto) { - for (var v : values()) { + for (AggregationInvocation v : values()) { if (v.proto == proto) { return v; } @@ -1041,7 +1041,7 @@ public io.substrait.proto.AggregationPhase toProto() { } public static AggregationPhase fromProto(io.substrait.proto.AggregationPhase proto) { - for (var v : values()) { + for (AggregationPhase v : values()) { if (v.proto == proto) { return v; } @@ -1069,7 +1069,7 @@ public io.substrait.proto.SortField.SortDirection toProto() { } public static SortDirection fromProto(io.substrait.proto.SortField.SortDirection proto) { - for (var v : values()) { + for (SortDirection v : values()) { if (v.proto == proto) { return v; } @@ -1097,7 +1097,7 @@ public io.substrait.proto.Expression.Cast.FailureBehavior toProto() { public static FailureBehavior fromProto( io.substrait.proto.Expression.Cast.FailureBehavior proto) { - for (var v : values()) { + for (FailureBehavior v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 55e71b78d..6480c5bec 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -7,6 +7,7 @@ import io.substrait.type.Type; import io.substrait.util.DecimalUtil; import java.math.BigDecimal; +import java.nio.ByteBuffer; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneOffset; @@ -89,7 +90,7 @@ public static Expression.TimestampLiteral timestamp(boolean nullable, long value */ @Deprecated public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateTime value) { - var epochMicro = + long epochMicro = TimeUnit.SECONDS.toMicros(value.toEpochSecond(ZoneOffset.UTC)) + TimeUnit.NANOSECONDS.toMicros(value.toLocalTime().getNano()); return timestamp(nullable, epochMicro); @@ -127,7 +128,7 @@ public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, long v */ @Deprecated public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, Instant value) { - var epochMicro = + long epochMicro = TimeUnit.SECONDS.toMicros(value.getEpochSecond()) + TimeUnit.NANOSECONDS.toMicros(value.getNano()); return timestampTZ(nullable, epochMicro); @@ -195,7 +196,7 @@ public static Expression.IntervalCompoundLiteral intervalCompound( } public static Expression.UUIDLiteral uuid(boolean nullable, ByteString uuid) { - var bb = uuid.asReadOnlyByteBuffer(); + ByteBuffer bb = uuid.asReadOnlyByteBuffer(); return Expression.UUIDLiteral.builder() .nullable(nullable) .value(new UUID(bb.getLong(), bb.getLong())) @@ -237,7 +238,7 @@ public static Expression.DecimalLiteral decimal( public static Expression.DecimalLiteral decimal( boolean nullable, BigDecimal value, int precision, int scale) { - var twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16); + byte[] twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16); return Expression.DecimalLiteral.builder() .nullable(nullable) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index f2926f473..b5c42fdd7 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -225,7 +225,7 @@ private static FieldReference of( Collections.reverse(segments); for (int i = 0; i < segments.size(); i++) { if (i == 0) { - var last = segments.get(0); + ReferenceSegment last = segments.get(0); reference = struct == null ? last.constructOnExpression(expression) : last.constructOnRoot(struct); } else { diff --git a/core/src/main/java/io/substrait/expression/FunctionArg.java b/core/src/main/java/io/substrait/expression/FunctionArg.java index 495def8ad..95a437fd5 100644 --- a/core/src/main/java/io/substrait/expression/FunctionArg.java +++ b/core/src/main/java/io/substrait/expression/FunctionArg.java @@ -34,13 +34,13 @@ static FuncArgVisitor expressionVisitor) { - return new FuncArgVisitor<>() { + return new FuncArgVisitor() { @Override public FunctionArgument visitExpr( SimpleExtension.Function fnDef, int argIdx, Expression e, EmptyVisitationContext context) throws RuntimeException { - var pE = e.accept(expressionVisitor, context); + io.substrait.proto.Expression pE = e.accept(expressionVisitor, context); return FunctionArgument.newBuilder().setValue(pE).build(); } @@ -48,7 +48,7 @@ public FunctionArgument visitExpr( public FunctionArgument visitType( SimpleExtension.Function fnDef, int argIdx, Type t, EmptyVisitationContext context) throws RuntimeException { - var pTyp = t.accept(typeVisitor); + io.substrait.proto.Type pTyp = t.accept(typeVisitor); return FunctionArgument.newBuilder().setType(pTyp).build(); } @@ -56,7 +56,7 @@ public FunctionArgument visitType( public FunctionArgument visitEnumArg( SimpleExtension.Function fnDef, int argIdx, EnumArg ea, EmptyVisitationContext context) throws RuntimeException { - var enumBldr = FunctionArgument.newBuilder(); + FunctionArgument.Builder enumBldr = FunctionArgument.newBuilder(); if (ea.value().isPresent()) { enumBldr = enumBldr.setEnum(ea.value().get()); @@ -82,18 +82,22 @@ public ProtoFrom( public FunctionArg convert( SimpleExtension.Function funcDef, int argIdx, FunctionArgument fArg) { - return switch (fArg.getArgTypeCase()) { - case TYPE -> protoTypeConverter.from(fArg.getType()); - case VALUE -> protoExprConverter.from(fArg.getValue()); - case ENUM -> { - SimpleExtension.EnumArgument enumArgDef = - (SimpleExtension.EnumArgument) funcDef.args().get(argIdx); - var optionValue = fArg.getEnum(); - yield EnumArg.of(enumArgDef, optionValue); - } - default -> throw new UnsupportedOperationException( - String.format("Unable to convert FunctionArgument %s.", fArg)); - }; + switch (fArg.getArgTypeCase()) { + case TYPE: + return protoTypeConverter.from(fArg.getType()); + case VALUE: + return protoExprConverter.from(fArg.getValue()); + case ENUM: + { + SimpleExtension.EnumArgument enumArgDef = + (SimpleExtension.EnumArgument) funcDef.args().get(argIdx); + String optionValue = fArg.getEnum(); + return EnumArg.of(enumArgDef, optionValue); + } + default: + throw new UnsupportedOperationException( + String.format("Unable to convert FunctionArgument %s.", fArg)); + } } } } diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 146243e18..0e7593d27 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -70,7 +70,7 @@ public Expression visit( } private Expression lit(Consumer consumer) { - var builder = Expression.Literal.newBuilder(); + Expression.Literal.Builder builder = Expression.Literal.newBuilder(); consumer.accept(builder); return Expression.newBuilder().setLiteral(builder).build(); } @@ -278,12 +278,12 @@ public Expression visit( io.substrait.expression.Expression.MapLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var keyValues = + List keyValues = expr.values().entrySet().stream() .map( e -> { - var key = toLiteral(e.getKey()); - var value = toLiteral(e.getValue()); + Expression.Literal key = toLiteral(e.getKey()); + Expression.Literal value = toLiteral(e.getValue()); return Expression.Literal.Map.KeyValue.newBuilder() .setKey(key) .setValue(value) @@ -300,7 +300,7 @@ public Expression visit( io.substrait.expression.Expression.EmptyMapLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var protoMapType = toProto(expr.getType()); + Type protoMapType = toProto(expr.getType()); bldr.setEmptyMap(protoMapType.getMap()) // For empty maps, the Literal message's own nullable field should be ignored // in favor of the nullability of the Type.Map in the literal's @@ -316,7 +316,7 @@ public Expression visit( io.substrait.expression.Expression.ListLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var values = + List values = expr.values().stream() .map(this::toLiteral) .collect(java.util.stream.Collectors.toList()); @@ -331,7 +331,7 @@ public Expression visit( throws RuntimeException { return lit( builder -> { - var protoListType = toProto(expr.getType()); + Type protoListType = toProto(expr.getType()); builder .setEmptyList(protoListType.getList()) // For empty lists, the Literal message's own nullable field should be ignored @@ -348,7 +348,7 @@ public Expression visit( io.substrait.expression.Expression.StructLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { - var values = + List values = expr.fields().stream() .map(this::toLiteral) .collect(java.util.stream.Collectors.toList()); @@ -360,7 +360,7 @@ public Expression visit( @Override public Expression visit( io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { - var typeReference = + int typeReference = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); return lit( bldr -> { @@ -378,7 +378,7 @@ public Expression visit( } private Expression.Literal toLiteral(io.substrait.expression.Expression expression) { - var e = toProto(expression); + Expression e = toProto(expression); assert e.getRexTypeCase() == Expression.RexTypeCase.LITERAL; return e.getLiteral(); } @@ -386,7 +386,7 @@ private Expression.Literal toLiteral(io.substrait.expression.Expression expressi @Override public Expression visit( io.substrait.expression.Expression.Switch expr, EmptyVisitationContext context) { - var clauses = + List clauses = expr.switchClauses().stream() .map( s -> @@ -407,7 +407,7 @@ public Expression visit( @Override public Expression visit( io.substrait.expression.Expression.IfThen expr, EmptyVisitationContext context) { - var clauses = + List clauses = expr.ifClauses().stream() .map( s -> @@ -427,7 +427,8 @@ public Expression visit( io.substrait.expression.Expression.ScalarFunctionInvocation expr, EmptyVisitationContext context) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, this); + FunctionArg.FuncArgVisitor + argVisitor = FunctionArg.toProto(typeProtoConverter, this); return Expression.newBuilder() .setScalarFunction( @@ -499,22 +500,28 @@ public Expression visit( public Expression visit(FieldReference expr, EmptyVisitationContext context) { Expression.ReferenceSegment seg = null; - for (var segment : expr.segments()) { + for (FieldReference.ReferenceSegment segment : expr.segments()) { Expression.ReferenceSegment.Builder protoSegment; - if (segment instanceof FieldReference.StructField f) { - var bldr = Expression.ReferenceSegment.StructField.newBuilder().setField(f.offset()); + if (segment instanceof FieldReference.StructField) { + FieldReference.StructField f = (FieldReference.StructField) segment; + Expression.ReferenceSegment.StructField.Builder bldr = + Expression.ReferenceSegment.StructField.newBuilder().setField(f.offset()); if (seg != null) { bldr.setChild(seg); } protoSegment = Expression.ReferenceSegment.newBuilder().setStructField(bldr); - } else if (segment instanceof FieldReference.ListElement f) { - var bldr = Expression.ReferenceSegment.ListElement.newBuilder().setOffset(f.offset()); + } else if (segment instanceof FieldReference.ListElement) { + FieldReference.ListElement f = (FieldReference.ListElement) segment; + Expression.ReferenceSegment.ListElement.Builder bldr = + Expression.ReferenceSegment.ListElement.newBuilder().setOffset(f.offset()); if (seg != null) { bldr.setChild(seg); } protoSegment = Expression.ReferenceSegment.newBuilder().setListElement(bldr); - } else if (segment instanceof FieldReference.MapKey f) { - var bldr = Expression.ReferenceSegment.MapKey.newBuilder().setMapKey(toLiteral(f.key())); + } else if (segment instanceof FieldReference.MapKey) { + FieldReference.MapKey f = (FieldReference.MapKey) segment; + Expression.ReferenceSegment.MapKey.Builder bldr = + Expression.ReferenceSegment.MapKey.newBuilder().setMapKey(toLiteral(f.key())); if (seg != null) { bldr.setChild(seg); } @@ -522,11 +529,11 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { } else { throw new IllegalArgumentException("Unhandled type: " + segment); } - var builtSegment = protoSegment.build(); - seg = builtSegment; + seg = protoSegment.build(); } - var out = Expression.FieldReference.newBuilder().setDirectReference(seg); + Expression.FieldReference.Builder out = + Expression.FieldReference.newBuilder().setDirectReference(seg); if (expr.inputExpression().isPresent()) { out.setExpression(toProto(expr.inputExpression().get())); @@ -591,7 +598,8 @@ public Expression visit( io.substrait.expression.Expression.WindowFunctionInvocation expr, EmptyVisitationContext context) throws RuntimeException { - var argVisitor = FunctionArg.toProto(typeProtoConverter, this); + FunctionArg.FuncArgVisitor + argVisitor = FunctionArg.toProto(typeProtoConverter, this); List args = expr.arguments().stream() .map(a -> a.accept(expr.declaration(), 0, argVisitor, context)) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 4a4d6a4fc..52f780f9a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -3,12 +3,14 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; +import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; import io.substrait.proto.ConsistentPartitionWindowRel; +import io.substrait.proto.Expression.FieldReference.ReferenceTypeCase; import io.substrait.proto.FunctionArgument; import io.substrait.proto.SortField; import io.substrait.relation.ConsistentPartitionWindow; @@ -53,211 +55,241 @@ public ProtoExpressionConverter( } public FieldReference from(io.substrait.proto.Expression.FieldReference reference) { - switch (reference.getReferenceTypeCase()) { - case DIRECT_REFERENCE -> { - io.substrait.proto.Expression.ReferenceSegment segment = reference.getDirectReference(); + io.substrait.proto.Expression.FieldReference.ReferenceTypeCase refTypeCase = + reference.getReferenceTypeCase(); - var segments = new ArrayList(); - while (segment != io.substrait.proto.Expression.ReferenceSegment.getDefaultInstance()) { - segments.add( - switch (segment.getReferenceTypeCase()) { - case MAP_KEY -> { - var mapKey = segment.getMapKey(); - segment = mapKey.getChild(); - yield FieldReference.MapKey.of(from(mapKey.getMapKey())); - } - case STRUCT_FIELD -> { - var structField = segment.getStructField(); - segment = structField.getChild(); - yield FieldReference.StructField.of(structField.getField()); - } - case LIST_ELEMENT -> { - var listElement = segment.getListElement(); - segment = listElement.getChild(); - yield FieldReference.ListElement.of(listElement.getOffset()); - } - case REFERENCETYPE_NOT_SET -> throw new IllegalArgumentException( - "Unhandled type: " + segment.getReferenceTypeCase()); - }); - } - Collections.reverse(segments); - var fieldReference = - switch (reference.getRootTypeCase()) { - case EXPRESSION -> FieldReference.ofExpression( - from(reference.getExpression()), segments); - case ROOT_REFERENCE -> FieldReference.ofRoot(rootType, segments); - case OUTER_REFERENCE -> FieldReference.newRootStructOuterReference( - reference.getDirectReference().getStructField().getField(), - rootType, - reference.getOuterReference().getStepsOut()); - case ROOTTYPE_NOT_SET -> throw new IllegalArgumentException( - "Unhandled type: " + reference.getRootTypeCase()); - }; + if (refTypeCase == ReferenceTypeCase.MASKED_REFERENCE) { + throw new IllegalArgumentException("Unsupported type: " + refTypeCase); + } + + if (refTypeCase != ReferenceTypeCase.DIRECT_REFERENCE) { + throw new IllegalArgumentException("Unhandled type: " + refTypeCase); + } + + switch (reference.getRootTypeCase()) { + case EXPRESSION: + return FieldReference.ofExpression( + from(reference.getExpression()), + getDirectReferenceSegments(reference.getDirectReference())); + case ROOT_REFERENCE: + return FieldReference.ofRoot( + rootType, getDirectReferenceSegments(reference.getDirectReference())); + case OUTER_REFERENCE: + return FieldReference.newRootStructOuterReference( + reference.getDirectReference().getStructField().getField(), + rootType, + reference.getOuterReference().getStepsOut()); + case ROOTTYPE_NOT_SET: + default: + throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); + } + } + + private List getDirectReferenceSegments( + io.substrait.proto.Expression.ReferenceSegment segment) { + List results = new ArrayList<>(); - return fieldReference; + while (segment != io.substrait.proto.Expression.ReferenceSegment.getDefaultInstance()) { + final ReferenceSegment mappedSegment; + switch (segment.getReferenceTypeCase()) { + case MAP_KEY: + io.substrait.proto.Expression.ReferenceSegment.MapKey mapKey = segment.getMapKey(); + segment = mapKey.getChild(); + mappedSegment = FieldReference.MapKey.of(from(mapKey.getMapKey())); + break; + case STRUCT_FIELD: + io.substrait.proto.Expression.ReferenceSegment.StructField structField = + segment.getStructField(); + segment = structField.getChild(); + mappedSegment = FieldReference.StructField.of(structField.getField()); + break; + case LIST_ELEMENT: + io.substrait.proto.Expression.ReferenceSegment.ListElement listElement = + segment.getListElement(); + segment = listElement.getChild(); + mappedSegment = FieldReference.ListElement.of(listElement.getOffset()); + break; + case REFERENCETYPE_NOT_SET: + default: + throw new IllegalArgumentException("Unhandled type: " + segment.getReferenceTypeCase()); } - case MASKED_REFERENCE -> throw new IllegalArgumentException( - "Unsupported type: " + reference.getReferenceTypeCase()); - default -> throw new IllegalArgumentException( - "Unhandled type: " + reference.getReferenceTypeCase()); + + results.add(mappedSegment); } + + Collections.reverse(results); + + return results; } public Expression from(io.substrait.proto.Expression expr) { - return switch (expr.getRexTypeCase()) { - case LITERAL -> from(expr.getLiteral()); - case SELECTION -> from(expr.getSelection()); - case SCALAR_FUNCTION -> { - var scalarFunction = expr.getScalarFunction(); - var functionReference = scalarFunction.getFunctionReference(); - var declaration = lookup.getScalarFunction(functionReference, extensions); - var pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); - var args = - IntStream.range(0, scalarFunction.getArgumentsCount()) - .mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) - .collect(java.util.stream.Collectors.toList()); - var options = - scalarFunction.getOptionsList().stream() - .map(ProtoExpressionConverter::fromFunctionOption) - .collect(Collectors.toList()); - yield Expression.ScalarFunctionInvocation.builder() - .addAllArguments(args) - .declaration(declaration) - .outputType(protoTypeConverter.from(scalarFunction.getOutputType())) - .options(options) - .build(); - } - case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction()); - case IF_THEN -> { - var ifThen = expr.getIfThen(); - var clauses = - ifThen.getIfsList().stream() - .map(t -> ExpressionCreator.ifThenClause(from(t.getIf()), from(t.getThen()))) - .collect(java.util.stream.Collectors.toList()); - yield ExpressionCreator.ifThenStatement(from(ifThen.getElse()), clauses); - } - case SWITCH_EXPRESSION -> { - var switchExpr = expr.getSwitchExpression(); - var clauses = - switchExpr.getIfsList().stream() - .map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen()))) - .collect(java.util.stream.Collectors.toList()); - yield ExpressionCreator.switchStatement( - from(switchExpr.getMatch()), from(switchExpr.getElse()), clauses); - } - case SINGULAR_OR_LIST -> { - var orList = expr.getSingularOrList(); - var values = - orList.getOptionsList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList()); - yield Expression.SingleOrList.builder() - .condition(from(orList.getValue())) - .addAllOptions(values) - .build(); - } - case MULTI_OR_LIST -> { - var multiOrList = expr.getMultiOrList(); - var values = - multiOrList.getOptionsList().stream() - .map( - t -> - Expression.MultiOrListRecord.builder() - .addAllValues( - t.getFieldsList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())) - .build()) - .collect(java.util.stream.Collectors.toList()); - yield Expression.MultiOrList.builder() - .addAllOptionCombinations(values) - .addAllConditions( - multiOrList.getValueList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())) - .build(); - } - case CAST -> ExpressionCreator.cast( - protoTypeConverter.from(expr.getCast().getType()), - from(expr.getCast().getInput()), - Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior())); - case SUBQUERY -> { - switch (expr.getSubquery().getSubqueryTypeCase()) { - case SET_PREDICATE -> { - var rel = protoRelConverter.from(expr.getSubquery().getSetPredicate().getTuples()); - yield Expression.SetPredicate.builder() - .tuples(rel) - .predicateOp( - Expression.PredicateOp.fromProto( - expr.getSubquery().getSetPredicate().getPredicateOp())) - .build(); - } - case SCALAR -> { - var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); - yield Expression.ScalarSubquery.builder() - .input(rel) - .type( - rel.getRecordType() - .accept( - new TypeVisitor.TypeThrowsVisitor( - "Expected struct field") { - @Override - public Type visit(Type.Struct type) throws RuntimeException { - if (type.fields().size() != 1) { - throw new UnsupportedOperationException( - "Scalar subquery must have exactly one field"); - } - // Result can be null if the query returns no rows - return type.fields().get(0); - } - })) - .build(); - } - case IN_PREDICATE -> { - var rel = protoRelConverter.from(expr.getSubquery().getInPredicate().getHaystack()); - var needles = - expr.getSubquery().getInPredicate().getNeedlesList().stream() - .map(e -> this.from(e)) - .collect(java.util.stream.Collectors.toList()); - yield Expression.InPredicate.builder().haystack(rel).needles(needles).build(); - } - case SET_COMPARISON -> { - throw new UnsupportedOperationException( - "Unsupported subquery type: " + expr.getSubquery().getSubqueryTypeCase()); - } - default -> { - throw new IllegalArgumentException( - "Unknown subquery type: " + expr.getSubquery().getSubqueryTypeCase()); + switch (expr.getRexTypeCase()) { + case LITERAL: + return from(expr.getLiteral()); + case SELECTION: + return from(expr.getSelection()); + case SCALAR_FUNCTION: + { + io.substrait.proto.Expression.ScalarFunction scalarFunction = expr.getScalarFunction(); + int functionReference = scalarFunction.getFunctionReference(); + SimpleExtension.ScalarFunctionVariant declaration = + lookup.getScalarFunction(functionReference, extensions); + FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); + List args = + IntStream.range(0, scalarFunction.getArgumentsCount()) + .mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) + .collect(Collectors.toList()); + List options = + scalarFunction.getOptionsList().stream() + .map(ProtoExpressionConverter::fromFunctionOption) + .collect(Collectors.toList()); + return Expression.ScalarFunctionInvocation.builder() + .addAllArguments(args) + .declaration(declaration) + .outputType(protoTypeConverter.from(scalarFunction.getOutputType())) + .options(options) + .build(); + } + case WINDOW_FUNCTION: + return fromWindowFunction(expr.getWindowFunction()); + case IF_THEN: + { + io.substrait.proto.Expression.IfThen ifThen = expr.getIfThen(); + List clauses = + ifThen.getIfsList().stream() + .map(t -> ExpressionCreator.ifThenClause(from(t.getIf()), from(t.getThen()))) + .collect(Collectors.toList()); + return ExpressionCreator.ifThenStatement(from(ifThen.getElse()), clauses); + } + case SWITCH_EXPRESSION: + { + io.substrait.proto.Expression.SwitchExpression switchExpr = expr.getSwitchExpression(); + List clauses = + switchExpr.getIfsList().stream() + .map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen()))) + .collect(Collectors.toList()); + return ExpressionCreator.switchStatement( + from(switchExpr.getMatch()), from(switchExpr.getElse()), clauses); + } + case SINGULAR_OR_LIST: + { + io.substrait.proto.Expression.SingularOrList orList = expr.getSingularOrList(); + List values = + orList.getOptionsList().stream().map(this::from).collect(Collectors.toList()); + return Expression.SingleOrList.builder() + .condition(from(orList.getValue())) + .addAllOptions(values) + .build(); + } + case MULTI_OR_LIST: + { + io.substrait.proto.Expression.MultiOrList multiOrList = expr.getMultiOrList(); + List values = + multiOrList.getOptionsList().stream() + .map( + t -> + Expression.MultiOrListRecord.builder() + .addAllValues( + t.getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())) + .build()) + .collect(Collectors.toList()); + return Expression.MultiOrList.builder() + .addAllOptionCombinations(values) + .addAllConditions( + multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList())) + .build(); + } + case CAST: + return ExpressionCreator.cast( + protoTypeConverter.from(expr.getCast().getType()), + from(expr.getCast().getInput()), + Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior())); + case SUBQUERY: + { + switch (expr.getSubquery().getSubqueryTypeCase()) { + case SET_PREDICATE: + { + io.substrait.relation.Rel rel = + protoRelConverter.from(expr.getSubquery().getSetPredicate().getTuples()); + return Expression.SetPredicate.builder() + .tuples(rel) + .predicateOp( + Expression.PredicateOp.fromProto( + expr.getSubquery().getSetPredicate().getPredicateOp())) + .build(); + } + case SCALAR: + { + io.substrait.relation.Rel rel = + protoRelConverter.from(expr.getSubquery().getScalar().getInput()); + return Expression.ScalarSubquery.builder() + .input(rel) + .type( + rel.getRecordType() + .accept( + new TypeVisitor.TypeThrowsVisitor( + "Expected struct field") { + @Override + public Type visit(Type.Struct type) throws RuntimeException { + if (type.fields().size() != 1) { + throw new UnsupportedOperationException( + "Scalar subquery must have exactly one field"); + } + // Result can be null if the query returns no rows + return type.fields().get(0); + } + })) + .build(); + } + case IN_PREDICATE: + { + io.substrait.relation.Rel rel = + protoRelConverter.from(expr.getSubquery().getInPredicate().getHaystack()); + List needles = + expr.getSubquery().getInPredicate().getNeedlesList().stream() + .map(e -> this.from(e)) + .collect(Collectors.toList()); + return Expression.InPredicate.builder().haystack(rel).needles(needles).build(); + } + case SET_COMPARISON: + throw new UnsupportedOperationException( + "Unsupported subquery type: " + expr.getSubquery().getSubqueryTypeCase()); + default: + throw new IllegalArgumentException( + "Unknown subquery type: " + expr.getSubquery().getSubqueryTypeCase()); } } - } // TODO enum. - case ENUM -> throw new UnsupportedOperationException( - "Unsupported type: " + expr.getRexTypeCase()); - default -> throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase()); - }; + case ENUM: + throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); + default: + throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase()); + } } public Expression.WindowFunctionInvocation fromWindowFunction( io.substrait.proto.Expression.WindowFunction windowFunction) { - var functionReference = windowFunction.getFunctionReference(); - var declaration = lookup.getWindowFunction(functionReference, extensions); - var argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); + int functionReference = windowFunction.getFunctionReference(); + SimpleExtension.WindowFunctionVariant declaration = + lookup.getWindowFunction(functionReference, extensions); + FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); - var args = + List args = fromFunctionArgumentList( windowFunction.getArgumentsCount(), argVisitor, declaration, windowFunction::getArguments); - var partitionExprs = + List partitionExprs = windowFunction.getPartitionsList().stream().map(this::from).collect(Collectors.toList()); - var sortFields = + List sortFields = windowFunction.getSortsList().stream() .map(this::fromSortField) .collect(Collectors.toList()); - var options = + List options = windowFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); @@ -282,17 +314,18 @@ public Expression.WindowFunctionInvocation fromWindowFunction( public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFunction( ConsistentPartitionWindowRel.WindowRelFunction windowRelFunction) { - var functionReference = windowRelFunction.getFunctionReference(); - var declaration = lookup.getWindowFunction(functionReference, extensions); - var argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); + int functionReference = windowRelFunction.getFunctionReference(); + SimpleExtension.WindowFunctionVariant declaration = + lookup.getWindowFunction(functionReference, extensions); + FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); - var args = + List args = fromFunctionArgumentList( windowRelFunction.getArgumentsCount(), argVisitor, declaration, windowRelFunction::getArguments); - var options = + List options = windowRelFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); @@ -314,125 +347,163 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti } private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { - return switch (bound.getKindCase()) { - case PRECEDING -> WindowBound.Preceding.of(bound.getPreceding().getOffset()); - case FOLLOWING -> WindowBound.Following.of(bound.getFollowing().getOffset()); - case CURRENT_ROW -> WindowBound.CURRENT_ROW; - case UNBOUNDED -> WindowBound.UNBOUNDED; - case KIND_NOT_SET -> - // per the spec, the lower and upper bounds default to the start or end of the partition - // respectively if not set - WindowBound.UNBOUNDED; - }; + switch (bound.getKindCase()) { + case PRECEDING: + return WindowBound.Preceding.of(bound.getPreceding().getOffset()); + case FOLLOWING: + return WindowBound.Following.of(bound.getFollowing().getOffset()); + case CURRENT_ROW: + return WindowBound.CURRENT_ROW; + case UNBOUNDED: + return WindowBound.UNBOUNDED; + case KIND_NOT_SET: + // per the spec, the lower and upper bounds default to the start or end of the partition + // respectively if not set + return WindowBound.UNBOUNDED; + default: + throw new IllegalArgumentException("Unsupported bound kind: " + bound.getKindCase()); + } } public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { - return switch (literal.getLiteralTypeCase()) { - case BOOLEAN -> ExpressionCreator.bool(literal.getNullable(), literal.getBoolean()); - case I8 -> ExpressionCreator.i8(literal.getNullable(), literal.getI8()); - case I16 -> ExpressionCreator.i16(literal.getNullable(), literal.getI16()); - case I32 -> ExpressionCreator.i32(literal.getNullable(), literal.getI32()); - case I64 -> ExpressionCreator.i64(literal.getNullable(), literal.getI64()); - case FP32 -> ExpressionCreator.fp32(literal.getNullable(), literal.getFp32()); - case FP64 -> ExpressionCreator.fp64(literal.getNullable(), literal.getFp64()); - case STRING -> ExpressionCreator.string(literal.getNullable(), literal.getString()); - case BINARY -> ExpressionCreator.binary(literal.getNullable(), literal.getBinary()); - case TIMESTAMP -> ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp()); - case TIMESTAMP_TZ -> ExpressionCreator.timestampTZ( - literal.getNullable(), literal.getTimestampTz()); - case PRECISION_TIMESTAMP -> ExpressionCreator.precisionTimestamp( - literal.getNullable(), - literal.getPrecisionTimestamp().getValue(), - literal.getPrecisionTimestamp().getPrecision()); - case PRECISION_TIMESTAMP_TZ -> ExpressionCreator.precisionTimestampTZ( - literal.getNullable(), - literal.getPrecisionTimestampTz().getValue(), - literal.getPrecisionTimestampTz().getPrecision()); - case DATE -> ExpressionCreator.date(literal.getNullable(), literal.getDate()); - case TIME -> ExpressionCreator.time(literal.getNullable(), literal.getTime()); - case INTERVAL_YEAR_TO_MONTH -> ExpressionCreator.intervalYear( - literal.getNullable(), - literal.getIntervalYearToMonth().getYears(), - literal.getIntervalYearToMonth().getMonths()); - case INTERVAL_DAY_TO_SECOND -> { - // Handle deprecated version that doesn't provide precision and that uses microseconds - // instead of subseconds, for backwards compatibility - int precision = - literal.getIntervalDayToSecond().hasPrecision() - ? literal.getIntervalDayToSecond().getPrecision() - : 6; // microseconds - long subseconds = - literal.getIntervalDayToSecond().hasPrecision() - ? literal.getIntervalDayToSecond().getSubseconds() - : literal.getIntervalDayToSecond().getMicroseconds(); - yield ExpressionCreator.intervalDay( + switch (literal.getLiteralTypeCase()) { + case BOOLEAN: + return ExpressionCreator.bool(literal.getNullable(), literal.getBoolean()); + case I8: + return ExpressionCreator.i8(literal.getNullable(), literal.getI8()); + case I16: + return ExpressionCreator.i16(literal.getNullable(), literal.getI16()); + case I32: + return ExpressionCreator.i32(literal.getNullable(), literal.getI32()); + case I64: + return ExpressionCreator.i64(literal.getNullable(), literal.getI64()); + case FP32: + return ExpressionCreator.fp32(literal.getNullable(), literal.getFp32()); + case FP64: + return ExpressionCreator.fp64(literal.getNullable(), literal.getFp64()); + case STRING: + return ExpressionCreator.string(literal.getNullable(), literal.getString()); + case BINARY: + return ExpressionCreator.binary(literal.getNullable(), literal.getBinary()); + case TIMESTAMP: + return ExpressionCreator.timestamp(literal.getNullable(), literal.getTimestamp()); + case TIMESTAMP_TZ: + return ExpressionCreator.timestampTZ(literal.getNullable(), literal.getTimestampTz()); + case PRECISION_TIMESTAMP: + return ExpressionCreator.precisionTimestamp( literal.getNullable(), - literal.getIntervalDayToSecond().getDays(), - literal.getIntervalDayToSecond().getSeconds(), - subseconds, - precision); - } - case INTERVAL_COMPOUND -> { - if (!literal.getIntervalCompound().getIntervalDayToSecond().hasPrecision()) { - throw new RuntimeException( - "Interval compound with deprecated version of interval day (ie. no precision) is not supported"); + literal.getPrecisionTimestamp().getValue(), + literal.getPrecisionTimestamp().getPrecision()); + case PRECISION_TIMESTAMP_TZ: + return ExpressionCreator.precisionTimestampTZ( + literal.getNullable(), + literal.getPrecisionTimestampTz().getValue(), + literal.getPrecisionTimestampTz().getPrecision()); + case DATE: + return ExpressionCreator.date(literal.getNullable(), literal.getDate()); + case TIME: + return ExpressionCreator.time(literal.getNullable(), literal.getTime()); + case INTERVAL_YEAR_TO_MONTH: + return ExpressionCreator.intervalYear( + literal.getNullable(), + literal.getIntervalYearToMonth().getYears(), + literal.getIntervalYearToMonth().getMonths()); + case INTERVAL_DAY_TO_SECOND: + { + // Handle deprecated version that doesn't provide precision and that uses microseconds + // instead of subseconds, for backwards compatibility + int precision = + literal.getIntervalDayToSecond().hasPrecision() + ? literal.getIntervalDayToSecond().getPrecision() + : 6; // microseconds + long subseconds = + literal.getIntervalDayToSecond().hasPrecision() + ? literal.getIntervalDayToSecond().getSubseconds() + : literal.getIntervalDayToSecond().getMicroseconds(); + return ExpressionCreator.intervalDay( + literal.getNullable(), + literal.getIntervalDayToSecond().getDays(), + literal.getIntervalDayToSecond().getSeconds(), + subseconds, + precision); + } + case INTERVAL_COMPOUND: + { + if (!literal.getIntervalCompound().getIntervalDayToSecond().hasPrecision()) { + throw new RuntimeException( + "Interval compound with deprecated version of interval day (ie. no precision) is not supported"); + } + return ExpressionCreator.intervalCompound( + literal.getNullable(), + literal.getIntervalCompound().getIntervalYearToMonth().getYears(), + literal.getIntervalCompound().getIntervalYearToMonth().getMonths(), + literal.getIntervalCompound().getIntervalDayToSecond().getDays(), + literal.getIntervalCompound().getIntervalDayToSecond().getSeconds(), + literal.getIntervalCompound().getIntervalDayToSecond().getSubseconds(), + literal.getIntervalCompound().getIntervalDayToSecond().getPrecision()); } - yield ExpressionCreator.intervalCompound( + case FIXED_CHAR: + return ExpressionCreator.fixedChar(literal.getNullable(), literal.getFixedChar()); + case VAR_CHAR: + return ExpressionCreator.varChar( literal.getNullable(), - literal.getIntervalCompound().getIntervalYearToMonth().getYears(), - literal.getIntervalCompound().getIntervalYearToMonth().getMonths(), - literal.getIntervalCompound().getIntervalDayToSecond().getDays(), - literal.getIntervalCompound().getIntervalDayToSecond().getSeconds(), - literal.getIntervalCompound().getIntervalDayToSecond().getSubseconds(), - literal.getIntervalCompound().getIntervalDayToSecond().getPrecision()); - } - case FIXED_CHAR -> ExpressionCreator.fixedChar(literal.getNullable(), literal.getFixedChar()); - case VAR_CHAR -> ExpressionCreator.varChar( - literal.getNullable(), literal.getVarChar().getValue(), literal.getVarChar().getLength()); - case FIXED_BINARY -> ExpressionCreator.fixedBinary( - literal.getNullable(), literal.getFixedBinary()); - case DECIMAL -> ExpressionCreator.decimal( - literal.getNullable(), - literal.getDecimal().getValue(), - literal.getDecimal().getPrecision(), - literal.getDecimal().getScale()); - case STRUCT -> ExpressionCreator.struct( - literal.getNullable(), - literal.getStruct().getFieldsList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())); - case MAP -> ExpressionCreator.map( - literal.getNullable(), - literal.getMap().getKeyValuesList().stream() - .collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue())))); - case EMPTY_MAP -> { - // literal.getNullable() is intentionally ignored in favor of the nullability - // specified in the literal.getEmptyMap() type. - var mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); - yield ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value()); - } - case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid()); - case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull())); - case LIST -> ExpressionCreator.list( - literal.getNullable(), - literal.getList().getValuesList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())); - case EMPTY_LIST -> { - // literal.getNullable() is intentionally ignored in favor of the nullability - // specified in the literal.getEmptyList() type. - var listType = protoTypeConverter.fromList(literal.getEmptyList()); - yield ExpressionCreator.emptyList(listType.nullable(), listType.elementType()); - } - case USER_DEFINED -> { - var userDefinedLiteral = literal.getUserDefined(); - var type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - yield ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); - } - default -> throw new IllegalStateException( - "Unexpected value: " + literal.getLiteralTypeCase()); - }; + literal.getVarChar().getValue(), + literal.getVarChar().getLength()); + case FIXED_BINARY: + return ExpressionCreator.fixedBinary(literal.getNullable(), literal.getFixedBinary()); + case DECIMAL: + return ExpressionCreator.decimal( + literal.getNullable(), + literal.getDecimal().getValue(), + literal.getDecimal().getPrecision(), + literal.getDecimal().getScale()); + case STRUCT: + return ExpressionCreator.struct( + literal.getNullable(), + literal.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case MAP: + return ExpressionCreator.map( + literal.getNullable(), + literal.getMap().getKeyValuesList().stream() + .collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue())))); + case EMPTY_MAP: + { + // literal.getNullable() is intentionally ignored in favor of the nullability + // specified in the literal.getEmptyMap() type. + Type.Map mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); + return ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value()); + } + case UUID: + return ExpressionCreator.uuid(literal.getNullable(), literal.getUuid()); + case NULL: + return ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull())); + case LIST: + return ExpressionCreator.list( + literal.getNullable(), + literal.getList().getValuesList().stream() + .map(this::from) + .collect(Collectors.toList())); + case EMPTY_LIST: + { + // literal.getNullable() is intentionally ignored in favor of the nullability + // specified in the literal.getEmptyList() type. + Type.ListType listType = protoTypeConverter.fromList(literal.getEmptyList()); + return ExpressionCreator.emptyList(listType.nullable(), listType.elementType()); + } + case USER_DEFINED: + { + io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = + literal.getUserDefined(); + SimpleExtension.Type type = + lookup.getType(userDefinedLiteral.getTypeReference(), extensions); + return ExpressionCreator.userDefinedLiteral( + literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); + } + default: + throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); + } } private static List fromFunctionArgumentList( diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index 00c340b53..e0cb9d1c5 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -26,7 +26,10 @@ public ExtendedExpression toProto( for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReferenceBase expressionReference : extendedExpression.getReferredExpressions()) { if (expressionReference - instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) { + instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference) { + io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et = + (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference) + expressionReference; io.substrait.proto.Expression expressionProto = et.getExpression().accept(expressionProtoConverter, EmptyVisitationContext.INSTANCE); ExpressionReference.Builder expressionReferenceBuilder = @@ -36,8 +39,10 @@ public ExtendedExpression toProto( builder.addReferredExpr(expressionReferenceBuilder); } else if (expressionReference instanceof - io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference - aft) { + io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference) { + io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference aft = + (io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference) + expressionReference; ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setMeasure( diff --git a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java index 51aa38ebd..a2ad68f70 100644 --- a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java @@ -15,7 +15,7 @@ public AbstractExtensionLookup( public SimpleExtension.ScalarFunctionVariant getScalarFunction( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = functionAnchorMap.get(reference); + SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -26,7 +26,7 @@ public SimpleExtension.ScalarFunctionVariant getScalarFunction( public SimpleExtension.WindowFunctionVariant getWindowFunction( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = functionAnchorMap.get(reference); + SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -37,7 +37,7 @@ public SimpleExtension.WindowFunctionVariant getWindowFunction( public SimpleExtension.AggregateFunctionVariant getAggregateFunction( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = functionAnchorMap.get(reference); + SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -48,7 +48,7 @@ public SimpleExtension.AggregateFunctionVariant getAggregateFunction( public SimpleExtension.Type getType( int reference, SimpleExtension.ExtensionCollection extensions) { - var anchor = typeAnchorMap.get(reference); + SimpleExtension.TypeAnchor anchor = typeAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown type id. Make sure that the type id provided was shared in the extensions section of the plan."); diff --git a/core/src/main/java/io/substrait/extension/AdvancedExtension.java b/core/src/main/java/io/substrait/extension/AdvancedExtension.java index fb5370244..bd717a636 100644 --- a/core/src/main/java/io/substrait/extension/AdvancedExtension.java +++ b/core/src/main/java/io/substrait/extension/AdvancedExtension.java @@ -14,7 +14,8 @@ public abstract class AdvancedExtension { public abstract Optional getEnhancement(); public io.substrait.proto.AdvancedExtension toProto(RelProtoConverter relProtoConverter) { - var builder = io.substrait.proto.AdvancedExtension.newBuilder(); + io.substrait.proto.AdvancedExtension.Builder builder = + io.substrait.proto.AdvancedExtension.newBuilder(); getEnhancement().ifPresent(e -> builder.setEnhancement(e.toProto(relProtoConverter))); getOptimizations().forEach(e -> builder.addOptimization(e.toProto(relProtoConverter))); return builder.build(); diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 5cfb03589..3dc0fe7c1 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,6 +1,5 @@ package io.substrait.extension; -import com.github.bsideup.jabel.Desugar; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; @@ -56,23 +55,23 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris().values()); - builder.addAllExtensions(simpleExtensions.extensionList()); + builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensions(simpleExtensions.extensionList); } public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris().values()); - builder.addAllExtensions(simpleExtensions.extensionList()); + builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensions(simpleExtensions.extensionList); } private SimpleExtensions getExtensions() { - var uriPos = new AtomicInteger(1); - var uris = new HashMap(); + AtomicInteger uriPos = new AtomicInteger(1); + HashMap uris = new HashMap<>(); - var extensionList = new ArrayList(); - for (var e : funcMap.forwardMap.entrySet()) { + ArrayList extensionList = new ArrayList<>(); + for (Map.Entry e : funcMap.forwardMap.entrySet()) { SimpleExtensionURI uri = uris.computeIfAbsent( e.getValue().namespace(), @@ -81,7 +80,7 @@ private SimpleExtensions getExtensions() { .setExtensionUriAnchor(uriPos.getAndIncrement()) .setUri(k) .build()); - var decl = + SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() .setExtensionFunction( SimpleExtensionDeclaration.ExtensionFunction.newBuilder() @@ -91,7 +90,7 @@ private SimpleExtensions getExtensions() { .build(); extensionList.add(decl); } - for (var e : typeMap.forwardMap.entrySet()) { + for (Map.Entry e : typeMap.forwardMap.entrySet()) { SimpleExtensionURI uri = uris.computeIfAbsent( e.getValue().namespace(), @@ -100,7 +99,7 @@ private SimpleExtensions getExtensions() { .setExtensionUriAnchor(uriPos.getAndIncrement()) .setUri(k) .build()); - var decl = + SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() .setExtensionType( SimpleExtensionDeclaration.ExtensionType.newBuilder() @@ -113,10 +112,17 @@ private SimpleExtensions getExtensions() { return new SimpleExtensions(uris, extensionList); } - @Desugar - private record SimpleExtensions( - HashMap uris, - ArrayList extensionList) {} + private static final class SimpleExtensions { + final HashMap uris; + final ArrayList extensionList; + + SimpleExtensions( + HashMap uris, + ArrayList extensionList) { + this.uris = uris; + this.extensionList = extensionList; + } + } /** We don't depend on guava... */ private static class BidiMap { diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 70034d9b1..7e1c81417 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -46,12 +46,12 @@ private Builder from( List simpleExtensionURIs, List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); - for (var extension : simpleExtensionURIs) { + for (SimpleExtensionURI extension : simpleExtensionURIs) { namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (var extension : simpleExtensionDeclarations) { + for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } @@ -68,7 +68,7 @@ private Builder from( } // Add all types used in plan to the typeMap - for (var extension : simpleExtensionDeclarations) { + for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 56cf2f2ba..5c090ce1c 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -271,9 +271,7 @@ public FunctionAnchor getAnchor() { private final Supplier> requiredArgsSupplier = Util.memoize( () -> { - return args().stream() - .filter(Argument::required) - .collect(java.util.stream.Collectors.toList()); + return args().stream().filter(Argument::required).collect(Collectors.toList()); }); public static String constructKeyFromTypes( @@ -656,7 +654,7 @@ private void checkNamespace(String name) { } public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { - var variant = aggregateFunctionsLookup.get().get(anchor); + AggregateFunctionVariant variant = aggregateFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -670,7 +668,7 @@ public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { } public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { - var variant = windowFunctionsLookup.get().get(anchor); + WindowFunctionVariant variant = windowFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -697,7 +695,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { } public static ExtensionCollection loadDefaults() { - var defaultFiles = + List defaultFiles = Arrays.asList( "boolean", "aggregate_generic", @@ -712,7 +710,7 @@ public static ExtensionCollection loadDefaults() { "string") .stream() .map(c -> String.format("/functions_%s.yaml", c)) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); return load(defaultFiles); } @@ -722,17 +720,17 @@ public static ExtensionCollection load(List resourcePaths) { throw new IllegalArgumentException("Require at least one resource path."); } - var extensions = + List extensions = resourcePaths.stream() .map( path -> { - try (var stream = ExtensionCollection.class.getResourceAsStream(path)) { + try (InputStream stream = ExtensionCollection.class.getResourceAsStream(path)) { return load(path, stream); } catch (IOException e) { throw new RuntimeException(e); } }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); ExtensionCollection complete = extensions.get(0); for (int i = 1; i < extensions.size(); i++) { complete = complete.merge(extensions.get(i)); @@ -742,7 +740,7 @@ public static ExtensionCollection load(List resourcePaths) { public static ExtensionCollection load(String namespace, String str) { try { - var doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); + ExtensionSignatures doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); return buildExtensionCollection(namespace, doc); } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -751,7 +749,8 @@ public static ExtensionCollection load(String namespace, String str) { public static ExtensionCollection load(String namespace, InputStream stream) { try { - var doc = objectMapper(namespace).readValue(stream, ExtensionSignatures.class); + ExtensionSignatures doc = + objectMapper(namespace).readValue(stream, ExtensionSignatures.class); return buildExtensionCollection(namespace, doc); } catch (RuntimeException ex) { throw ex; @@ -765,12 +764,12 @@ public static ExtensionCollection buildExtensionCollection( List scalarFunctionVariants = extensionSignatures.scalars().stream() .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); List aggregateFunctionVariants = extensionSignatures.aggregates().stream() .flatMap(t -> t.resolve(namespace)) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); Stream windowFunctionVariants = extensionSignatures.windows().stream().flatMap(t -> t.resolve(namespace)); @@ -794,7 +793,7 @@ public static ExtensionCollection buildExtensionCollection( Stream.concat(windowFunctionVariants, windowAggFunctionVariants) .collect(Collectors.toList()); - var collection = + ImmutableSimpleExtension.ExtensionCollection collection = ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) .aggregateFunctions(aggregateFunctionVariants) diff --git a/core/src/main/java/io/substrait/hint/Hint.java b/core/src/main/java/io/substrait/hint/Hint.java index 238bf44e7..7497f0529 100644 --- a/core/src/main/java/io/substrait/hint/Hint.java +++ b/core/src/main/java/io/substrait/hint/Hint.java @@ -12,7 +12,8 @@ public abstract class Hint { public abstract List getOutputNames(); public RelCommon.Hint toProto() { - var builder = RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames()); + RelCommon.Hint.Builder builder = + RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames()); getAlias().ifPresent(builder::setAlias); return builder.build(); } diff --git a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java index 9accc5de2..8d7947096 100644 --- a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java @@ -33,7 +33,7 @@ public DdlRel.DdlObject toProto() { } public static DdlObject fromProto(DdlRel.DdlObject proto) { - for (var v : values()) { + for (DdlObject v : values()) { if (v.proto == proto) { return v; } @@ -61,7 +61,7 @@ public DdlRel.DdlOp toProto() { } public static DdlOp fromProto(DdlRel.DdlOp proto) { - for (var v : values()) { + for (DdlOp v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/relation/AbstractWriteRel.java b/core/src/main/java/io/substrait/relation/AbstractWriteRel.java index 23014724f..fe7f49bee 100644 --- a/core/src/main/java/io/substrait/relation/AbstractWriteRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractWriteRel.java @@ -33,7 +33,7 @@ public WriteRel.WriteOp toProto() { } public static WriteOp fromProto(WriteRel.WriteOp proto) { - for (var v : values()) { + for (WriteOp v : values()) { if (v.proto == proto) { return v; } @@ -60,7 +60,7 @@ public WriteRel.CreateMode toProto() { } public static CreateMode fromProto(WriteRel.CreateMode proto) { - for (var v : values()) { + for (CreateMode v : values()) { if (v.proto == proto) { return v; } @@ -85,7 +85,7 @@ public WriteRel.OutputMode toProto() { } public static OutputMode fromProto(WriteRel.OutputMode proto) { - for (var v : values()) { + for (OutputMode v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java index 7a9d0f569..08776ed5d 100644 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java @@ -3,9 +3,13 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.proto.AggregateFunction; +import io.substrait.proto.FunctionArgument; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; +import java.util.List; +import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -25,9 +29,10 @@ public AggregateFunctionProtoConverter(ExtensionCollector functionCollector) { } public AggregateFunction toProto(Aggregate.Measure measure) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = measure.getFunction().arguments(); - var aggFuncDef = measure.getFunction().declaration(); + FunctionArg.FuncArgVisitor + argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + List args = measure.getFunction().arguments(); + SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); return AggregateFunction.newBuilder() .setPhase(measure.getFunction().aggregationPhase().toProto()) @@ -39,7 +44,7 @@ public AggregateFunction toProto(Aggregate.Measure measure) { i -> args.get(i) .accept(aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .setFunctionReference( functionCollector.getFunctionReference(measure.getFunction().declaration())) .build(); diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java index 63e868f63..1d07ea6ca 100644 --- a/core/src/main/java/io/substrait/relation/Expand.java +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -53,8 +53,8 @@ public abstract static class SwitchingField implements ExpandField { public abstract List getDuplicates(); public Type getType() { - var nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); - var type = getDuplicates().get(0).getType(); + boolean nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); + Type type = getDuplicates().get(0).getType(); return nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type); } diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 8709f9d02..57132a940 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -212,9 +212,10 @@ public Optional visit( @Override public Optional visit(Expression.Switch expr, EmptyVisitationContext context) throws E { - var match = expr.match().accept(this, context); - var switchClauses = transformList(expr.switchClauses(), context, this::visitSwitchClause); - var defaultClause = expr.defaultClause().accept(this, context); + Optional match = expr.match().accept(this, context); + Optional> switchClauses = + transformList(expr.switchClauses(), context, this::visitSwitchClause); + Optional defaultClause = expr.defaultClause().accept(this, context); if (allEmpty(match, switchClauses, defaultClause)) { return Optional.empty(); @@ -242,8 +243,9 @@ protected Optional visitSwitchClause( @Override public Optional visit(Expression.IfThen ifThen, EmptyVisitationContext context) throws E { - var ifClauses = transformList(ifThen.ifClauses(), context, this::visitIfClause); - var elseClause = ifThen.elseClause().accept(this, context); + Optional> ifClauses = + transformList(ifThen.ifClauses(), context, this::visitIfClause); + Optional elseClause = ifThen.elseClause().accept(this, context); if (allEmpty(ifClauses, elseClause)) { return Optional.empty(); @@ -258,8 +260,8 @@ public Optional visit(Expression.IfThen ifThen, EmptyVisitationConte protected Optional visitIfClause( Expression.IfClause ifClause, EmptyVisitationContext context) throws E { - var condition = ifClause.condition().accept(this, context); - var then = ifClause.then().accept(this, context); + Optional condition = ifClause.condition().accept(this, context); + Optional then = ifClause.then().accept(this, context); if (allEmpty(condition, then)) { return Optional.empty(); @@ -287,9 +289,10 @@ public Optional visit( @Override public Optional visit( Expression.WindowFunctionInvocation wfi, EmptyVisitationContext context) throws E { - var arguments = visitFunctionArguments(wfi.arguments(), context); - var partitionBy = visitExprList(wfi.partitionBy(), context); - var sort = transformList(wfi.sort(), context, this::visitSortField); + Optional> arguments = visitFunctionArguments(wfi.arguments(), context); + Optional> partitionBy = visitExprList(wfi.partitionBy(), context); + Optional> sort = + transformList(wfi.sort(), context, this::visitSortField); if (allEmpty(arguments, partitionBy, sort)) { return Optional.empty(); @@ -313,8 +316,8 @@ public Optional visit(Expression.Cast cast, EmptyVisitationContext c @Override public Optional visit( Expression.SingleOrList singleOrList, EmptyVisitationContext context) throws E { - var condition = singleOrList.condition().accept(this, context); - var options = visitExprList(singleOrList.options(), context); + Optional condition = singleOrList.condition().accept(this, context); + Optional> options = visitExprList(singleOrList.options(), context); if (allEmpty(condition, options)) { return Optional.empty(); @@ -330,8 +333,8 @@ public Optional visit( @Override public Optional visit( Expression.MultiOrList multiOrList, EmptyVisitationContext context) throws E { - var conditions = visitExprList(multiOrList.conditions(), context); - var optionCombinations = + Optional> conditions = visitExprList(multiOrList.conditions(), context); + Optional> optionCombinations = transformList(multiOrList.optionCombinations(), context, this::visitMultiOrListRecord); if (allEmpty(conditions, optionCombinations)) { @@ -359,7 +362,8 @@ protected Optional visitMultiOrListRecord( @Override public Optional visit(FieldReference fieldReference, EmptyVisitationContext context) throws E { - var inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); + Optional inputExpression = + visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); @@ -389,8 +393,8 @@ public Optional visit( @Override public Optional visit( Expression.InPredicate inPredicate, EmptyVisitationContext context) throws E { - var haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); - var needles = visitExprList(inPredicate.needles(), context); + Optional haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); + Optional> needles = visitExprList(inPredicate.needles(), context); if (allEmpty(haystack, needles)) { return Optional.empty(); @@ -425,8 +429,8 @@ protected Optional> visitFunctionArguments( funcArgs, context, (arg, c) -> { - if (arg instanceof Expression expr) { - return expr.accept(this, c).flatMap(Optional::of); + if (arg instanceof Expression) { + return ((Expression) arg).accept(this, c).flatMap(Optional::of); } else { return Optional.empty(); } diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index 490bd7315..adb9cd535 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -51,7 +51,7 @@ public JoinRel.JoinType toProto() { } public static JoinType fromProto(JoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -63,33 +63,49 @@ public static JoinType fromProto(JoinRel.JoinType proto) { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER, RIGHT_SINGLE -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_SEMI, RIGHT_ANTI -> Stream - .of(); // these are right joins which ignore left side columns - case RIGHT_MARK -> Stream.of( - TypeCreator.REQUIRED - .BOOLEAN); // right mark join keeps all fields from right and adds a boolean mark - // field - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER, LEFT_SINGLE -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case SEMI, ANTI, LEFT_SEMI, LEFT_ANTI -> Stream - .of(); // these are left joins which ignore right side columns - case LEFT_MARK -> Stream.of( - TypeCreator.REQUIRED - .BOOLEAN); // left mark join keeps all fields from left and adds a boolean mark - // field - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case RIGHT: + case OUTER: + case RIGHT_SINGLE: + return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); + case RIGHT_SEMI: + case RIGHT_ANTI: + // these are right joins which ignore left side columns + return Stream.of(); + case RIGHT_MARK: + // right mark join keeps all fields from right and adds a boolean mark field + return Stream.of(TypeCreator.REQUIRED.BOOLEAN); + default: + return getLeft().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + case LEFT_SINGLE: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case SEMI: + case ANTI: + case LEFT_SEMI: + case LEFT_ANTI: + // these are left joins which ignore right side columns + return Stream.of(); + case LEFT_MARK: + // left mark join keeps all fields from left and adds a boolean mark field + return Stream.of(TypeCreator.REQUIRED.BOOLEAN); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index c359ea420..f005a29be 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -67,77 +67,56 @@ public Plan.Root from(io.substrait.proto.RelRoot rel) { } public Rel from(io.substrait.proto.Rel rel) { - var relType = rel.getRelTypeCase(); + io.substrait.proto.Rel.RelTypeCase relType = rel.getRelTypeCase(); switch (relType) { - case READ -> { + case READ: return newRead(rel.getRead()); - } - case FILTER -> { + case FILTER: return newFilter(rel.getFilter()); - } - case FETCH -> { + case FETCH: return newFetch(rel.getFetch()); - } - case AGGREGATE -> { + case AGGREGATE: return newAggregate(rel.getAggregate()); - } - case SORT -> { + case SORT: return newSort(rel.getSort()); - } - case JOIN -> { + case JOIN: return newJoin(rel.getJoin()); - } - case SET -> { + case SET: return newSet(rel.getSet()); - } - case PROJECT -> { + case PROJECT: return newProject(rel.getProject()); - } - case EXPAND -> { + case EXPAND: return newExpand(rel.getExpand()); - } - case CROSS -> { + case CROSS: return newCross(rel.getCross()); - } - case EXTENSION_LEAF -> { + case EXTENSION_LEAF: return newExtensionLeaf(rel.getExtensionLeaf()); - } - case EXTENSION_SINGLE -> { + case EXTENSION_SINGLE: return newExtensionSingle(rel.getExtensionSingle()); - } - case EXTENSION_MULTI -> { + case EXTENSION_MULTI: return newExtensionMulti(rel.getExtensionMulti()); - } - case HASH_JOIN -> { + case HASH_JOIN: return newHashJoin(rel.getHashJoin()); - } - case MERGE_JOIN -> { + case MERGE_JOIN: return newMergeJoin(rel.getMergeJoin()); - } - case NESTED_LOOP_JOIN -> { + case NESTED_LOOP_JOIN: return newNestedLoopJoin(rel.getNestedLoopJoin()); - } - case WINDOW -> { + case WINDOW: return newConsistentPartitionWindow(rel.getWindow()); - } - case WRITE -> { + case WRITE: return newWrite(rel.getWrite()); - } - case DDL -> { + case DDL: return newDdl(rel.getDdl()); - } - case UPDATE -> { + case UPDATE: return newUpdate(rel.getUpdate()); - } - default -> { + default: throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); - } } } protected Rel newRead(ReadRel rel) { if (rel.hasVirtualTable()) { - var virtualTable = rel.getVirtualTable(); + ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); if (virtualTable.getValuesCount() == 0) { return newEmptyScan(rel); } else { @@ -155,21 +134,20 @@ protected Rel newRead(ReadRel rel) { } protected Rel newWrite(WriteRel rel) { - var relType = rel.getWriteTypeCase(); + WriteRel.WriteTypeCase relType = rel.getWriteTypeCase(); switch (relType) { - case NAMED_TABLE -> { + case NAMED_TABLE: return newNamedWrite(rel); - } - case EXTENSION_TABLE -> { + case EXTENSION_TABLE: return newExtensionWrite(rel); - } - default -> throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); + default: + throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); } } protected NamedWrite newNamedWrite(WriteRel rel) { - var input = from(rel.getInput()); - var builder = + Rel input = from(rel.getInput()); + ImmutableNamedWrite.Builder builder = NamedWrite.builder() .input(input) .names(rel.getNamedTable().getNamesList()) @@ -186,9 +164,10 @@ protected NamedWrite newNamedWrite(WriteRel rel) { } protected Rel newExtensionWrite(WriteRel rel) { - var input = from(rel.getInput()); - var detail = detailFromWriteExtensionObject(rel.getExtensionTable().getDetail()); - var builder = + Rel input = from(rel.getInput()); + Extension.WriteExtensionObject detail = + detailFromWriteExtensionObject(rel.getExtensionTable().getDetail()); + ImmutableExtensionWrite.Builder builder = ExtensionWrite.builder() .input(input) .detail(detail) @@ -205,20 +184,19 @@ protected Rel newExtensionWrite(WriteRel rel) { } protected Rel newDdl(DdlRel rel) { - var relType = rel.getWriteTypeCase(); + DdlRel.WriteTypeCase relType = rel.getWriteTypeCase(); switch (relType) { - case NAMED_OBJECT -> { + case NAMED_OBJECT: return newNamedDdl(rel); - } - case EXTENSION_OBJECT -> { + case EXTENSION_OBJECT: return newExtensionDdl(rel); - } - default -> throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); + default: + throw new UnsupportedOperationException("Unsupported WriteTypeCase of " + relType); } } protected NamedDdl newNamedDdl(DdlRel rel) { - var tableSchema = newNamedStruct(rel.getTableSchema()); + NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); return NamedDdl.builder() .names(rel.getNamedObject().getNamesList()) .tableSchema(tableSchema) @@ -233,8 +211,9 @@ protected NamedDdl newNamedDdl(DdlRel rel) { } protected ExtensionDdl newExtensionDdl(DdlRel rel) { - var detail = detailFromDdlExtensionObject(rel.getExtensionObject().getDetail()); - var tableSchema = newNamedStruct(rel.getTableSchema()); + Extension.DdlExtensionObject detail = + detailFromDdlExtensionObject(rel.getExtensionObject().getDetail()); + NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); return ExtensionDdl.builder() .detail(detail) .tableSchema(newNamedStruct(rel.getTableSchema())) @@ -254,7 +233,8 @@ protected Optional optionalViewDefinition(DdlRel rel) { protected Expression.StructLiteral tableDefaults( io.substrait.proto.Expression.Literal.Struct struct, NamedStruct tableSchema) { - var converter = new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); return Expression.StructLiteral.builder() .fields( struct.getFieldsList().stream() @@ -264,29 +244,29 @@ protected Expression.StructLiteral tableDefaults( } protected Rel newUpdate(UpdateRel rel) { - var relType = rel.getUpdateTypeCase(); + UpdateRel.UpdateTypeCase relType = rel.getUpdateTypeCase(); switch (relType) { - case NAMED_TABLE -> { + case NAMED_TABLE: return newNamedUpdate(rel); - } - default -> throw new UnsupportedOperationException( - "Unsupported UpdateTypeCase of " + relType); + default: + throw new UnsupportedOperationException("Unsupported UpdateTypeCase of " + relType); } } protected Rel newNamedUpdate(UpdateRel rel) { - var tableSchema = newNamedStruct(rel.getTableSchema()); - var converter = new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); + NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); List transformations = new ArrayList<>(rel.getTransformationsCount()); - for (var transformation : rel.getTransformationsList()) { + for (UpdateRel.TransformExpression transformation : rel.getTransformationsList()) { transformations.add( NamedUpdate.TransformExpression.builder() .transformation(converter.from(transformation.getTransformation())) .columnTarget(transformation.getColumnTarget()) .build()); } - var builder = + ImmutableNamedUpdate.Builder builder = NamedUpdate.builder() .names(rel.getNamedTable().getNamesList()) .tableSchema(tableSchema) @@ -299,8 +279,8 @@ protected Rel newNamedUpdate(UpdateRel rel) { } protected Filter newFilter(FilterRel rel) { - var input = from(rel.getInput()); - var builder = + Rel input = from(rel.getInput()); + ImmutableFilter.Builder builder = Filter.builder() .input(input) .condition( @@ -321,7 +301,7 @@ protected NamedStruct newNamedStruct(ReadRel rel) { } protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) { - var struct = namedStruct.getStruct(); + io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); return NamedStruct.builder() .names(namedStruct.getNamesList()) .struct( @@ -336,8 +316,8 @@ protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) } protected EmptyScan newEmptyScan(ReadRel rel) { - var namedStruct = newNamedStruct(rel); - var builder = + NamedStruct namedStruct = newNamedStruct(rel); + ImmutableEmptyScan.Builder builder = EmptyScan.builder() .initialSchema(namedStruct) .bestEffortFilter( @@ -367,7 +347,7 @@ protected EmptyScan newEmptyScan(ReadRel rel) { protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { Extension.LeafRelDetail detail = detailFromExtensionLeafRel(rel.getDetail()); - var builder = + ImmutableExtensionLeaf.Builder builder = ExtensionLeaf.from(detail) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -378,7 +358,7 @@ protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) { Extension.SingleRelDetail detail = detailFromExtensionSingleRel(rel.getDetail()); Rel input = from(rel.getInput()); - var builder = + ImmutableExtensionSingle.Builder builder = ExtensionSingle.from(detail, input) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -389,7 +369,7 @@ protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) { protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { Extension.MultiRelDetail detail = detailFromExtensionMultiRel(rel.getDetail()); List inputs = rel.getInputsList().stream().map(this::from).collect(Collectors.toList()); - var builder = + ImmutableExtensionMulti.Builder builder = ExtensionMulti.from(detail, inputs) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -401,8 +381,8 @@ protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { } protected NamedScan newNamedScan(ReadRel rel) { - var namedStruct = newNamedStruct(rel); - var builder = + NamedStruct namedStruct = newNamedStruct(rel); + ImmutableNamedScan.Builder builder = NamedScan.builder() .initialSchema(namedStruct) .names(rel.getNamedTable().getNamesList()) @@ -434,7 +414,7 @@ protected NamedScan newNamedScan(ReadRel rel) { protected ExtensionTable newExtensionTable(ReadRel rel) { Extension.ExtensionTableDetail detail = detailFromExtensionTable(rel.getExtensionTable().getDetail()); - var builder = ExtensionTable.from(detail); + ImmutableExtensionTable.Builder builder = ExtensionTable.from(detail); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -447,9 +427,9 @@ protected ExtensionTable newExtensionTable(ReadRel rel) { } protected LocalFiles newLocalFiles(ReadRel rel) { - var namedStruct = newNamedStruct(rel); + NamedStruct namedStruct = newNamedStruct(rel); - var builder = + ImmutableLocalFiles.Builder builder = LocalFiles.builder() .initialSchema(namedStruct) .addAllItems( @@ -482,7 +462,7 @@ protected LocalFiles newLocalFiles(ReadRel rel) { } protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { - var builder = + io.substrait.relation.files.ImmutableFileOrFiles.Builder builder = FileOrFiles.builder() .partitionIndex(file.getPartitionIndex()) .start(file.getStart()) @@ -496,13 +476,14 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { } else if (file.hasDwrf()) { builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); } else if (file.hasText()) { - var ffBuilder = - FileFormat.DelimiterSeparatedTextReadOptions.builder() - .fieldDelimiter(file.getText().getFieldDelimiter()) - .maxLineSize(file.getText().getMaxLineSize()) - .quote(file.getText().getQuote()) - .headerLinesToSkip(file.getText().getHeaderLinesToSkip()) - .escape(file.getText().getEscape()); + io.substrait.relation.files.ImmutableFileFormat.DelimiterSeparatedTextReadOptions.Builder + ffBuilder = + FileFormat.DelimiterSeparatedTextReadOptions.builder() + .fieldDelimiter(file.getText().getFieldDelimiter()) + .maxLineSize(file.getText().getMaxLineSize()) + .quote(file.getText().getQuote()) + .headerLinesToSkip(file.getText().getHeaderLinesToSkip()) + .escape(file.getText().getEscape()); if (file.getText().hasValueTreatedAsNull()) { ffBuilder.valueTreatedAsNull(file.getText().getValueTreatedAsNull()); } @@ -523,12 +504,12 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { } protected VirtualTableScan newVirtualTable(ReadRel rel) { - var virtualTable = rel.getVirtualTable(); - var virtualTableSchema = newNamedStruct(rel); - var converter = + ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); + NamedStruct virtualTableSchema = newNamedStruct(rel); + ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this); List structLiterals = new ArrayList<>(virtualTable.getValuesCount()); - for (var struct : virtualTable.getValuesList()) { + for (io.substrait.proto.Expression.Literal.Struct struct : virtualTable.getValuesList()) { structLiterals.add( Expression.StructLiteral.builder() .fields( @@ -538,7 +519,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { .build()); } - var builder = + ImmutableVirtualTableScan.Builder builder = VirtualTableScan.builder() .bestEffortFilter( Optional.ofNullable( @@ -558,8 +539,8 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { } protected Fetch newFetch(FetchRel rel) { - var input = from(rel.getInput()); - var builder = Fetch.builder().input(input).offset(rel.getOffset()); + Rel input = from(rel.getInput()); + ImmutableFetch.Builder builder = Fetch.builder().input(input).offset(rel.getOffset()); if (rel.getCount() != -1) { // -1 is used as a sentinel value to signal LIMIT ALL // count only needs to be set when it is not -1 @@ -577,9 +558,10 @@ protected Fetch newFetch(FetchRel rel) { } protected Project newProject(ProjectRel rel) { - var input = from(rel.getInput()); - var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var builder = + Rel input = from(rel.getInput()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + ImmutableProject.Builder builder = Project.builder() .input(input) .expressions( @@ -598,28 +580,33 @@ protected Project newProject(ProjectRel rel) { } protected Expand newExpand(ExpandRel rel) { - var input = from(rel.getInput()); - var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var builder = + Rel input = from(rel.getInput()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + ImmutableExpand.Builder builder = Expand.builder() .input(input) .fields( rel.getFieldsList().stream() .map( - expandField -> - switch (expandField.getFieldTypeCase()) { - case CONSISTENT_FIELD -> Expand.ConsistentField.builder() + expandField -> { + switch (expandField.getFieldTypeCase()) { + case CONSISTENT_FIELD: + return Expand.ConsistentField.builder() .expression(converter.from(expandField.getConsistentField())) .build(); - case SWITCHING_FIELD -> Expand.SwitchingField.builder() + case SWITCHING_FIELD: + return Expand.SwitchingField.builder() .duplicates( expandField.getSwitchingField().getDuplicatesList().stream() .map(converter::from) .collect(java.util.stream.Collectors.toList())) .build(); - case FIELDTYPE_NOT_SET -> throw new UnsupportedOperationException( - "Expand fields not set"); - }) + case FIELDTYPE_NOT_SET: + default: + throw new UnsupportedOperationException("Expand fields not set"); + } + }) .collect(java.util.stream.Collectors.toList())); builder @@ -630,14 +617,14 @@ protected Expand newExpand(ExpandRel rel) { } protected Aggregate newAggregate(AggregateRel rel) { - var input = from(rel.getInput()); - var protoExprConverter = + Rel input = from(rel.getInput()); + ProtoExpressionConverter protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var protoAggrFuncConverter = + ProtoAggregateFunctionConverter protoAggrFuncConverter = new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); List groupings = new ArrayList<>(rel.getGroupingsCount()); - for (var grouping : rel.getGroupingsList()) { + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { groupings.add( Aggregate.Grouping.builder() .expressions( @@ -647,11 +634,12 @@ protected Aggregate newAggregate(AggregateRel rel) { .build()); } List measures = new ArrayList<>(rel.getMeasuresCount()); - var pF = new FunctionArg.ProtoFrom(protoExprConverter, protoTypeConverter); - for (var measure : rel.getMeasuresList()) { - var func = measure.getMeasure(); - var funcDecl = lookup.getAggregateFunction(func.getFunctionReference(), extensions); - var args = + FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(protoExprConverter, protoTypeConverter); + for (AggregateRel.Measure measure : rel.getMeasuresList()) { + io.substrait.proto.AggregateFunction func = measure.getMeasure(); + SimpleExtension.AggregateFunctionVariant funcDecl = + lookup.getAggregateFunction(func.getFunctionReference(), extensions); + List args = IntStream.range(0, measure.getMeasure().getArgumentsCount()) .mapToObj(i -> pF.convert(funcDecl, i, measure.getMeasure().getArguments(i))) .collect(java.util.stream.Collectors.toList()); @@ -663,7 +651,8 @@ protected Aggregate newAggregate(AggregateRel rel) { measure.hasFilter() ? protoExprConverter.from(measure.getFilter()) : null)) .build()); } - var builder = Aggregate.builder().input(input).groupings(groupings).measures(measures); + ImmutableAggregate.Builder builder = + Aggregate.builder().input(input).groupings(groupings).measures(measures); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -676,9 +665,10 @@ protected Aggregate newAggregate(AggregateRel rel) { } protected Sort newSort(SortRel rel) { - var input = from(rel.getInput()); - var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var builder = + Rel input = from(rel.getInput()); + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + ImmutableSort.Builder builder = Sort.builder() .input(input) .sortFields( @@ -707,8 +697,9 @@ protected Join newJoin(JoinRel rel) { Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + ImmutableJoin.Builder builder = Join.builder() .left(left) .right(right) @@ -731,7 +722,7 @@ protected Join newJoin(JoinRel rel) { protected Rel newCross(CrossRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - var builder = Cross.builder().left(left).right(right); + ImmutableCross.Builder builder = Cross.builder().left(left).right(right); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -747,7 +738,8 @@ protected Set newSet(SetRel rel) { rel.getInputsList().stream() .map(inputRel -> from(inputRel)) .collect(java.util.stream.Collectors.toList()); - var builder = Set.builder().inputs(inputs).setOp(Set.SetOp.fromProto(rel.getOp())); + ImmutableSet.Builder builder = + Set.builder().inputs(inputs).setOp(Set.SetOp.fromProto(rel.getOp())); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -762,16 +754,19 @@ protected Set newSet(SetRel rel) { protected Rel newHashJoin(HashJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - var leftKeys = rel.getLeftKeysList(); - var rightKeys = rel.getRightKeysList(); + List leftKeys = rel.getLeftKeysList(); + List rightKeys = rel.getRightKeysList(); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); - var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); - var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter leftConverter = + new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + ProtoExpressionConverter rightConverter = + new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + ProtoExpressionConverter unionConverter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + io.substrait.relation.physical.ImmutableHashJoin.Builder builder = HashJoin.builder() .left(left) .right(right) @@ -794,16 +789,19 @@ protected Rel newHashJoin(HashJoinRel rel) { protected Rel newMergeJoin(MergeJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - var leftKeys = rel.getLeftKeysList(); - var rightKeys = rel.getRightKeysList(); + List leftKeys = rel.getLeftKeysList(); + List rightKeys = rel.getRightKeysList(); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); - var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); - var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter leftConverter = + new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + ProtoExpressionConverter rightConverter = + new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + ProtoExpressionConverter unionConverter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + io.substrait.relation.physical.ImmutableMergeJoin.Builder builder = MergeJoin.builder() .left(left) .right(right) @@ -830,8 +828,9 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = + ProtoExpressionConverter converter = + new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + io.substrait.relation.physical.ImmutableNestedLoopJoin.Builder builder = NestedLoopJoin.builder() .left(left) .right(right) @@ -855,24 +854,24 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { protected ConsistentPartitionWindow newConsistentPartitionWindow( ConsistentPartitionWindowRel rel) { - var input = from(rel.getInput()); - var protoExpressionConverter = + Rel input = from(rel.getInput()); + ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - var partitionExprs = + List partitionExprs = rel.getPartitionExpressionsList().stream() .map(protoExpressionConverter::from) .collect(Collectors.toList()); - var sortFields = + List sortFields = rel.getSortsList().stream() .map(protoExpressionConverter::fromSortField) .collect(Collectors.toList()); - var windowRelFunctions = + List windowRelFunctions = rel.getWindowFunctionsList().stream() .map(protoExpressionConverter::fromWindowRelFunction) .collect(Collectors.toList()); - var builder = + ImmutableConsistentPartitionWindow.Builder builder = ConsistentPartitionWindow.builder() .input(input) .partitionExpressions(partitionExprs) @@ -896,8 +895,9 @@ protected static Optional optionalRelmap(io.substrait.proto.RelCommon protected static Optional optionalHint(io.substrait.proto.RelCommon relCommon) { if (!relCommon.hasHint()) return Optional.empty(); - var hint = relCommon.getHint(); - var builder = Hint.builder().addAllOutputNames(hint.getOutputNamesList()); + io.substrait.proto.RelCommon.Hint hint = relCommon.getHint(); + io.substrait.hint.ImmutableHint.Builder builder = + Hint.builder().addAllOutputNames(hint.getOutputNamesList()); if (!hint.getAlias().isEmpty()) { builder.alias(hint.getAlias()); } @@ -914,7 +914,7 @@ protected Optional optionalAdvancedExtension( protected AdvancedExtension advancedExtension( io.substrait.proto.AdvancedExtension advancedExtension) { - var builder = AdvancedExtension.builder(); + io.substrait.extension.ImmutableAdvancedExtension.Builder builder = AdvancedExtension.builder(); if (advancedExtension.hasEnhancement()) { builder.enhancement(enhancementFromAdvancedExtension(advancedExtension.getEnhancement())); } diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 87d99ba72..f5a3e8081 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -46,9 +46,11 @@ protected ExpressionCopyOnWriteVisitor getExpressionCopyOnWriteVisitor() { @Override public Optional visit(Aggregate aggregate, EmptyVisitationContext context) throws E { - var input = aggregate.getInput().accept(this, context); - var groupings = transformList(aggregate.getGroupings(), context, this::visitGrouping); - var measures = transformList(aggregate.getMeasures(), context, this::visitMeasure); + Optional input = aggregate.getInput().accept(this, context); + Optional> groupings = + transformList(aggregate.getGroupings(), context, this::visitGrouping); + Optional> measures = + transformList(aggregate.getMeasures(), context, this::visitMeasure); if (allEmpty(input, groupings, measures)) { return Optional.empty(); @@ -70,8 +72,10 @@ protected Optional visitGrouping( protected Optional visitMeasure( Aggregate.Measure measure, EmptyVisitationContext context) throws E { - var preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter(), context); - var afi = visitAggregateFunction(measure.getFunction(), context); + Optional preMeasureFilter = + visitOptionalExpression(measure.getPreMeasureFilter(), context); + Optional afi = + visitAggregateFunction(measure.getFunction(), context); if (allEmpty(preMeasureFilter, afi)) { return Optional.empty(); @@ -86,8 +90,9 @@ protected Optional visitMeasure( protected Optional visitAggregateFunction( AggregateFunctionInvocation afi, EmptyVisitationContext context) throws E { - var arguments = visitFunctionArguments(afi.arguments(), context); - var sort = transformList(afi.sort(), context, this::visitSortField); + Optional> arguments = visitFunctionArguments(afi.arguments(), context); + Optional> sort = + transformList(afi.sort(), context, this::visitSortField); if (allEmpty(arguments, sort)) { return Optional.empty(); @@ -124,8 +129,9 @@ public Optional visit(Fetch fetch, EmptyVisitationContext context) throws E @Override public Optional visit(Filter filter, EmptyVisitationContext context) throws E { - var input = filter.getInput().accept(this, context); - var condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); + Optional input = filter.getInput().accept(this, context); + Optional condition = + filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(input, condition)) { return Optional.empty(); @@ -140,10 +146,10 @@ public Optional visit(Filter filter, EmptyVisitationContext context) throws @Override public Optional visit(Join join, EmptyVisitationContext context) throws E { - var left = join.getLeft().accept(this, context); - var right = join.getRight().accept(this, context); - var condition = visitOptionalExpression(join.getCondition(), context); - var postFilter = visitOptionalExpression(join.getPostJoinFilter(), context); + Optional left = join.getLeft().accept(this, context); + Optional right = join.getRight().accept(this, context); + Optional condition = visitOptionalExpression(join.getCondition(), context); + Optional postFilter = visitOptionalExpression(join.getPostJoinFilter(), context); if (allEmpty(left, right, condition, postFilter)) { return Optional.empty(); @@ -166,7 +172,7 @@ public Optional visit(Set set, EmptyVisitationContext context) throws E { @Override public Optional visit(NamedScan namedScan, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(namedScan.getFilter(), context); + Optional filter = visitOptionalExpression(namedScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -177,7 +183,7 @@ public Optional visit(NamedScan namedScan, EmptyVisitationContext context) @Override public Optional visit(LocalFiles localFiles, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(localFiles.getFilter(), context); + Optional filter = visitOptionalExpression(localFiles.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -188,8 +194,8 @@ public Optional visit(LocalFiles localFiles, EmptyVisitationContext context @Override public Optional visit(Project project, EmptyVisitationContext context) throws E { - var input = project.getInput().accept(this, context); - var expressions = visitExprList(project.getExpressions(), context); + Optional input = project.getInput().accept(this, context); + Optional> expressions = visitExprList(project.getExpressions(), context); if (allEmpty(input, expressions)) { return Optional.empty(); @@ -234,8 +240,9 @@ public Optional visit(NamedUpdate update, EmptyVisitationContext context) t @Override public Optional visit(Sort sort, EmptyVisitationContext context) throws E { - var input = sort.getInput().accept(this, context); - var sortFields = transformList(sort.getSortFields(), context, this::visitSortField); + Optional input = sort.getInput().accept(this, context); + Optional> sortFields = + transformList(sort.getSortFields(), context, this::visitSortField); if (allEmpty(input, sortFields)) { return Optional.empty(); @@ -250,8 +257,8 @@ public Optional visit(Sort sort, EmptyVisitationContext context) throws E { @Override public Optional visit(Cross cross, EmptyVisitationContext context) throws E { - var left = cross.getLeft().accept(this, context); - var right = cross.getRight().accept(this, context); + Optional left = cross.getLeft().accept(this, context); + Optional right = cross.getRight().accept(this, context); if (allEmpty(left, right)) { return Optional.empty(); @@ -267,7 +274,7 @@ public Optional visit(Cross cross, EmptyVisitationContext context) throws E @Override public Optional visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(virtualTableScan.getFilter(), context); + Optional filter = visitOptionalExpression(virtualTableScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -303,7 +310,7 @@ public Optional visit(ExtensionMulti extensionMulti, EmptyVisitationContext @Override public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext context) throws E { - var filter = visitOptionalExpression(extensionTable.getFilter(), context); + Optional filter = visitOptionalExpression(extensionTable.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -317,11 +324,14 @@ public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext @Override public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) throws E { - var left = hashJoin.getLeft().accept(this, context); - var right = hashJoin.getRight().accept(this, context); - var leftKeys = transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); - var rightKeys = transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); - var postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter(), context); + Optional left = hashJoin.getLeft().accept(this, context); + Optional right = hashJoin.getRight().accept(this, context); + Optional> leftKeys = + transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); + Optional> rightKeys = + transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); + Optional postFilter = + visitOptionalExpression(hashJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { return Optional.empty(); @@ -339,11 +349,14 @@ public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) th @Override public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E { - var left = mergeJoin.getLeft().accept(this, context); - var right = mergeJoin.getRight().accept(this, context); - var leftKeys = transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); - var rightKeys = transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); - var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); + Optional left = mergeJoin.getLeft().accept(this, context); + Optional right = mergeJoin.getRight().accept(this, context); + Optional> leftKeys = + transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); + Optional> rightKeys = + transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); + Optional postFilter = + visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { return Optional.empty(); @@ -362,9 +375,9 @@ public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) @Override public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) throws E { - var left = nestedLoopJoin.getLeft().accept(this, context); - var right = nestedLoopJoin.getRight().accept(this, context); - var condition = + Optional left = nestedLoopJoin.getLeft().accept(this, context); + Optional right = nestedLoopJoin.getRight().accept(this, context); + Optional condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(left, right, condition)) { @@ -383,15 +396,16 @@ public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext public Optional visit( ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) throws E { - var windowFunctions = + Optional> windowFunctions = transformList( consistentPartitionWindow.getWindowFunctions(), context, this::visitWindowRelFunction); - var partitionExpressions = + Optional> partitionExpressions = transformList( consistentPartitionWindow.getPartitionExpressions(), context, (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c)); - var sorts = transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField); + Optional> sorts = + transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField); if (allEmpty(windowFunctions, partitionExpressions, sorts)) { return Optional.empty(); @@ -411,7 +425,8 @@ protected Optional visitW ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation, EmptyVisitationContext context) throws E { - var functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments(), context); + Optional> functionArgs = + visitFunctionArguments(windowRelFunctionInvocation.arguments(), context); if (allEmpty(functionArgs)) { return Optional.empty(); @@ -433,7 +448,8 @@ protected Optional> visitExprList( public Optional visitFieldReference( FieldReference fieldReference, EmptyVisitationContext context) throws E { - var inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); + Optional inputExpression = + visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); } @@ -447,12 +463,13 @@ protected Optional> visitFunctionArguments( funcArgs, context, (arg, c) -> { - if (arg instanceof Expression expr) { - return expr.accept(getExpressionCopyOnWriteVisitor(), c) + if (arg instanceof Expression) { + return ((Expression) arg) + .accept(getExpressionCopyOnWriteVisitor(), c) .flatMap(Optional::of); - } else { - return Optional.empty(); } + + return Optional.empty(); }); } diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index f75d2ab9d..54fdaa433 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan; import io.substrait.proto.AggregateFunction; import io.substrait.proto.AggregateRel; @@ -111,7 +112,7 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { - var builder = + AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) @@ -125,11 +126,13 @@ public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws Run } private AggregateRel.Measure toProto(Aggregate.Measure measure) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = measure.getFunction().arguments(); - var aggFuncDef = measure.getFunction().declaration(); + FunctionArg.FuncArgVisitor< + io.substrait.proto.FunctionArgument, EmptyVisitationContext, RuntimeException> + argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + List args = measure.getFunction().arguments(); + SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); - var func = + AggregateFunction.Builder func = AggregateFunction.newBuilder() .setPhase(measure.getFunction().aggregationPhase().toProto()) .setInvocation(measure.getFunction().invocation().toProto()) @@ -149,7 +152,7 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .map(ExpressionProtoConverter::from) .collect(Collectors.toList())); - var builder = AggregateRel.Measure.newBuilder().setMeasure(func); + AggregateRel.Measure.Builder builder = AggregateRel.Measure.newBuilder().setMeasure(func); measure.getPreMeasureFilter().ifPresent(f -> builder.setFilter(toProto(f))); return builder.build(); @@ -175,7 +178,7 @@ public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws Run @Override public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { - var builder = + FetchRel.Builder builder = FetchRel.newBuilder() .setCommon(common(fetch)) .setInput(toProto(fetch.getInput())) @@ -189,7 +192,7 @@ public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeExce @Override public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { - var builder = + FilterRel.Builder builder = FilterRel.newBuilder() .setCommon(common(filter)) .setInput(toProto(filter.getInput())) @@ -201,7 +204,7 @@ public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeEx @Override public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeException { - var builder = + JoinRel.Builder builder = JoinRel.newBuilder() .setCommon(common(join)) .setLeft(toProto(join.getLeft())) @@ -218,7 +221,8 @@ public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeExcept @Override public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeException { - var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); + SetRel.Builder builder = + SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); set.getInputs() .forEach( inputRel -> { @@ -231,7 +235,7 @@ public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeExceptio @Override public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws RuntimeException { - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(namedScan)) .setNamedTable(ReadRel.NamedTable.newBuilder().addAllNames(namedScan.getNames())) @@ -246,7 +250,7 @@ public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws Run @Override public Rel visit(LocalFiles localFiles, EmptyVisitationContext context) throws RuntimeException { - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(localFiles)) .setLocalFiles( @@ -269,7 +273,7 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) throws RuntimeException { ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder().setDetail(extensionTable.getDetail().toProto(this)); - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(extensionTable)) .setBaseSchema(extensionTable.getInitialSchema().toProto(typeProtoConverter)) @@ -281,7 +285,7 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) @Override public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { - var builder = + HashJoinRel.Builder builder = HashJoinRel.newBuilder() .setCommon(common(hashJoin)) .setLeft(toProto(hashJoin.getLeft())) @@ -306,7 +310,7 @@ public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws Runti @Override public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws RuntimeException { - var builder = + MergeJoinRel.Builder builder = MergeJoinRel.newBuilder() .setCommon(common(mergeJoin)) .setLeft(toProto(mergeJoin.getLeft())) @@ -332,7 +336,7 @@ public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws Run @Override public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) throws RuntimeException { - var builder = + NestedLoopJoinRel.Builder builder = NestedLoopJoinRel.newBuilder() .setCommon(common(nestedLoopJoin)) .setLeft(toProto(nestedLoopJoin.getLeft())) @@ -348,7 +352,7 @@ public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) public Rel visit( ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) throws RuntimeException { - var builder = + ConsistentPartitionWindowRel.Builder builder = ConsistentPartitionWindowRel.newBuilder() .setCommon(common(consistentPartitionWindow)) .setInput(toProto(consistentPartitionWindow.getInput())) @@ -367,7 +371,7 @@ public Rel visit( @Override public Rel visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { - var builder = + WriteRel.Builder builder = WriteRel.newBuilder() .setCommon(common(write)) .setInput(toProto(write.getInput())) @@ -382,7 +386,7 @@ public Rel visit(NamedWrite write, EmptyVisitationContext context) throws Runtim @Override public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws RuntimeException { - var builder = + WriteRel.Builder builder = WriteRel.newBuilder() .setCommon(common(write)) .setInput(toProto(write.getInput())) @@ -398,7 +402,7 @@ public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws Ru @Override public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { - var builder = + DdlRel.Builder builder = DdlRel.newBuilder() .setCommon(common(ddl)) .setTableSchema(ddl.getTableSchema().toProto(typeProtoConverter)) @@ -415,7 +419,7 @@ public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeExc @Override public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { - var builder = + DdlRel.Builder builder = DdlRel.newBuilder() .setCommon(common(ddl)) .setTableSchema(ddl.getTableSchema().toProto(typeProtoConverter)) @@ -433,7 +437,7 @@ public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Runtim @Override public Rel visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { - var builder = + UpdateRel.Builder builder = UpdateRel.newBuilder() .setNamedTable(NamedTable.newBuilder().addAllNames(update.getNames())) .setTableSchema(update.getTableSchema().toProto(typeProtoConverter)) @@ -460,11 +464,13 @@ private List toProtoWindowRelFun return windowRelFunctionInvocations.stream() .map( f -> { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = f.arguments(); - var aggFuncDef = f.declaration(); + FunctionArg.FuncArgVisitor< + io.substrait.proto.FunctionArgument, EmptyVisitationContext, RuntimeException> + argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + List args = f.arguments(); + SimpleExtension.WindowFunctionVariant aggFuncDef = f.declaration(); - var arguments = + List arguments = IntStream.range(0, args.size()) .mapToObj( i -> @@ -472,7 +478,7 @@ private List toProtoWindowRelFun .accept( aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) .collect(Collectors.toList()); - var options = + List options = f.options().stream() .map(ExpressionProtoConverter::from) .collect(Collectors.toList()); @@ -494,7 +500,7 @@ private List toProtoWindowRelFun @Override public Rel visit(Project project, EmptyVisitationContext context) throws RuntimeException { - var builder = + ProjectRel.Builder builder = ProjectRel.newBuilder() .setCommon(common(project)) .setInput(toProto(project.getInput())) @@ -506,20 +512,22 @@ public Rel visit(Project project, EmptyVisitationContext context) throws Runtime @Override public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { - var builder = + ExpandRel.Builder builder = ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput())); expand .getFields() .forEach( expandField -> { - if (expandField instanceof Expand.ConsistentField cf) { + if (expandField instanceof Expand.ConsistentField) { + Expand.ConsistentField cf = (Expand.ConsistentField) expandField; builder.addFields( ExpandRel.ExpandField.newBuilder() .setConsistentField(toProto(cf.getExpression())) .build()); - } else if (expandField instanceof Expand.SwitchingField sf) { + } else if (expandField instanceof Expand.SwitchingField) { + Expand.SwitchingField sf = (Expand.SwitchingField) expandField; builder.addFields( ExpandRel.ExpandField.newBuilder() .setSwitchingField( @@ -536,7 +544,7 @@ public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeEx @Override public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { - var builder = + SortRel.Builder builder = SortRel.newBuilder() .setCommon(common(sort)) .setInput(toProto(sort.getInput())) @@ -548,7 +556,7 @@ public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeExcept @Override public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { - var builder = + CrossRel.Builder builder = CrossRel.newBuilder() .setCommon(common(cross)) .setLeft(toProto(cross.getLeft())) @@ -561,7 +569,7 @@ public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeExce @Override public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) throws RuntimeException { - var builder = + ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(virtualTableScan)) .setVirtualTable( @@ -584,7 +592,7 @@ public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext conte @Override public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) throws RuntimeException { - var builder = + ExtensionLeafRel.Builder builder = ExtensionLeafRel.newBuilder() .setCommon(common(extensionLeaf)) .setDetail(extensionLeaf.getDetail().toProto(this)); @@ -594,7 +602,7 @@ public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) @Override public Rel visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) throws RuntimeException { - var builder = + ExtensionSingleRel.Builder builder = ExtensionSingleRel.newBuilder() .setCommon(common(extensionSingle)) .setInput(toProto(extensionSingle.getInput())) @@ -607,7 +615,7 @@ public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) throws RuntimeException { List inputs = extensionMulti.getInputs().stream().map(this::toProto).collect(Collectors.toList()); - var builder = + ExtensionMultiRel.Builder builder = ExtensionMultiRel.newBuilder() .setCommon(common(extensionMulti)) .addAllInputs(inputs) @@ -616,11 +624,11 @@ public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) } private RelCommon common(io.substrait.relation.Rel rel) { - var builder = RelCommon.newBuilder(); + RelCommon.Builder builder = RelCommon.newBuilder(); rel.getCommonExtension() .ifPresent(extension -> builder.setAdvancedExtension(extension.toProto(this))); - var remap = rel.getRemap().orElse(null); + io.substrait.relation.Rel.Remap remap = rel.getRemap().orElse(null); if (remap != null) { builder.setEmit(RelCommon.Emit.newBuilder().addAllOutputMapping(remap.indices())); } else { diff --git a/core/src/main/java/io/substrait/relation/Set.java b/core/src/main/java/io/substrait/relation/Set.java index 697cda1be..dcf29ca6b 100644 --- a/core/src/main/java/io/substrait/relation/Set.java +++ b/core/src/main/java/io/substrait/relation/Set.java @@ -35,7 +35,7 @@ public SetRel.SetOp toProto() { } public static SetOp fromProto(SetRel.SetOp proto) { - for (var v : values()) { + for (SetOp v : values()) { if (v.proto == proto) { return v; } @@ -66,14 +66,24 @@ protected Type.Struct deriveRecordType() { } // As defined in https://substrait.io/relations/logical_relations/#set-operation-types - return switch (getSetOp()) { - case UNKNOWN -> first; // alternative would be to throw an exception - case MINUS_PRIMARY, MINUS_PRIMARY_ALL, MINUS_MULTISET -> first; - case INTERSECTION_PRIMARY -> coalesceNullabilityIntersectionPrimary(first, rest); - case INTERSECTION_MULTISET, INTERSECTION_MULTISET_ALL -> coalesceNullabilityIntersection( - first, rest); - case UNION_DISTINCT, UNION_ALL -> coalesceNullabilityUnion(first, rest); - }; + switch (getSetOp()) { + case UNKNOWN: + return first; // alternative would be to throw an exception + case MINUS_PRIMARY: + case MINUS_PRIMARY_ALL: + case MINUS_MULTISET: + return first; + case INTERSECTION_PRIMARY: + return coalesceNullabilityIntersectionPrimary(first, rest); + case INTERSECTION_MULTISET: + case INTERSECTION_MULTISET_ALL: + return coalesceNullabilityIntersection(first, rest); + case UNION_DISTINCT: + case UNION_ALL: + return coalesceNullabilityUnion(first, rest); + default: + throw new UnsupportedOperationException("Unexpected set operation: " + getSetOp()); + } } /** If field is nullable in any of the inputs, it's nullable in the output */ diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 6eb7a361d..0b9f61e28 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -24,11 +24,11 @@ public abstract class VirtualTableScan extends AbstractReadRel { */ @Value.Check protected void check() { - var names = getInitialSchema().names(); + List names = getInitialSchema().names(); assert names.size() == NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct()); - var rows = getRows(); + List rows = getRows(); assert rows.size() > 0 && names.stream().noneMatch(s -> s == null) diff --git a/core/src/main/java/io/substrait/relation/files/FileFormat.java b/core/src/main/java/io/substrait/relation/files/FileFormat.java index ed9482d3b..415ff714e 100644 --- a/core/src/main/java/io/substrait/relation/files/FileFormat.java +++ b/core/src/main/java/io/substrait/relation/files/FileFormat.java @@ -7,35 +7,35 @@ public interface FileFormat { @Value.Immutable - abstract static class ParquetReadOptions implements FileFormat { + abstract class ParquetReadOptions implements FileFormat { public static ImmutableFileFormat.ParquetReadOptions.Builder builder() { return ImmutableFileFormat.ParquetReadOptions.builder(); } } @Value.Immutable - abstract static class ArrowReadOptions implements FileFormat { + abstract class ArrowReadOptions implements FileFormat { public static ImmutableFileFormat.ArrowReadOptions.Builder builder() { return ImmutableFileFormat.ArrowReadOptions.builder(); } } @Value.Immutable - abstract static class OrcReadOptions implements FileFormat { + abstract class OrcReadOptions implements FileFormat { public static ImmutableFileFormat.OrcReadOptions.Builder builder() { return ImmutableFileFormat.OrcReadOptions.builder(); } } @Value.Immutable - abstract static class DwrfReadOptions implements FileFormat { + abstract class DwrfReadOptions implements FileFormat { public static ImmutableFileFormat.DwrfReadOptions.Builder builder() { return ImmutableFileFormat.DwrfReadOptions.builder(); } } @Value.Immutable - abstract static class DelimiterSeparatedTextReadOptions implements FileFormat { + abstract class DelimiterSeparatedTextReadOptions implements FileFormat { public abstract String getFieldDelimiter(); public abstract long getMaxLineSize(); @@ -54,7 +54,7 @@ public static ImmutableFileFormat.DelimiterSeparatedTextReadOptions.Builder buil } @Value.Immutable - abstract static class Extension implements FileFormat { + abstract class Extension implements FileFormat { public abstract com.google.protobuf.Any getExtension(); public static ImmutableFileFormat.Extension.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java index 4d0bb421f..4b227025f 100644 --- a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java +++ b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java @@ -36,29 +36,33 @@ default ReadRel.LocalFiles.FileOrFiles toProto() { getFileFormat() .ifPresent( fileFormat -> { - if (fileFormat instanceof FileFormat.ParquetReadOptions options) { + if (fileFormat instanceof FileFormat.ParquetReadOptions) { builder.setParquet( ReadRel.LocalFiles.FileOrFiles.ParquetReadOptions.newBuilder().build()); - } else if (fileFormat instanceof FileFormat.ArrowReadOptions options) { + } else if (fileFormat instanceof FileFormat.ArrowReadOptions) { builder.setArrow( ReadRel.LocalFiles.FileOrFiles.ArrowReadOptions.newBuilder().build()); - } else if (fileFormat instanceof FileFormat.OrcReadOptions options) { + } else if (fileFormat instanceof FileFormat.OrcReadOptions) { builder.setOrc(ReadRel.LocalFiles.FileOrFiles.OrcReadOptions.newBuilder().build()); - } else if (fileFormat instanceof FileFormat.DwrfReadOptions options) { + } else if (fileFormat instanceof FileFormat.DwrfReadOptions) { builder.setDwrf( ReadRel.LocalFiles.FileOrFiles.DwrfReadOptions.newBuilder().build()); - } else if (fileFormat - instanceof FileFormat.DelimiterSeparatedTextReadOptions options) { - var optionsBuilder = - ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.newBuilder() - .setFieldDelimiter(options.getFieldDelimiter()) - .setMaxLineSize(options.getMaxLineSize()) - .setQuote(options.getQuote()) - .setHeaderLinesToSkip(options.getHeaderLinesToSkip()) - .setEscape(options.getEscape()); + } else if (fileFormat instanceof FileFormat.DelimiterSeparatedTextReadOptions) { + FileFormat.DelimiterSeparatedTextReadOptions options = + (FileFormat.DelimiterSeparatedTextReadOptions) fileFormat; + ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.Builder + optionsBuilder = + ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions + .newBuilder() + .setFieldDelimiter(options.getFieldDelimiter()) + .setMaxLineSize(options.getMaxLineSize()) + .setQuote(options.getQuote()) + .setHeaderLinesToSkip(options.getHeaderLinesToSkip()) + .setEscape(options.getEscape()); options.getValueTreatedAsNull().ifPresent(optionsBuilder::setValueTreatedAsNull); builder.setText(optionsBuilder.build()); - } else if (fileFormat instanceof FileFormat.Extension options) { + } else if (fileFormat instanceof FileFormat.Extension) { + FileFormat.Extension options = (FileFormat.Extension) fileFormat; builder.setExtension(options.getExtension()); } else { throw new UnsupportedOperationException( @@ -72,11 +76,14 @@ default ReadRel.LocalFiles.FileOrFiles toProto() { getPath() .ifPresent( path -> { - switch (pathType) { - case URI_PATH -> builder.setUriPath(path); - case URI_PATH_GLOB -> builder.setUriPathGlob(path); - case URI_FILE -> builder.setUriFile(path); - case URI_FOLDER -> builder.setUriFolder(path); + if (pathType == PathType.URI_PATH) { + builder.setUriPath(path); + } else if (pathType == PathType.URI_PATH_GLOB) { + builder.setUriPathGlob(path); + } else if (pathType == PathType.URI_FILE) { + builder.setUriFile(path); + } else if (pathType == PathType.URI_FOLDER) { + builder.setUriFolder(path); } })); diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index 1d9cc3c54..246afa151 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -43,7 +43,7 @@ public static enum JoinType { } public static JoinType fromProto(HashJoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -58,23 +58,37 @@ public HashJoinRel.JoinType toProto() { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case RIGHT: + case OUTER: + return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); + case RIGHT_ANTI: + case RIGHT_SEMI: + return Stream.empty(); + default: + return getLeft().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java index 4f7facd32..34a51487a 100644 --- a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java @@ -43,7 +43,7 @@ public static enum JoinType { } public static JoinType fromProto(MergeJoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -58,23 +58,37 @@ public MergeJoinRel.JoinType toProto() { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java index 233aff176..17af10df5 100644 --- a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -40,7 +40,7 @@ public NestedLoopJoinRel.JoinType toProto() { } public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { - for (var v : values()) { + for (JoinType v : values()) { if (v.proto == proto) { return v; } @@ -52,23 +52,37 @@ public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); - default -> getRight().getRecordType().fields().stream(); - }; + Stream leftTypes = getLeftTypes(); + Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } + private Stream getLeftTypes() { + switch (getJoinType()) { + case RIGHT: + case OUTER: + return getLeft().getRecordType().fields().stream().map(TypeCreator::asNullable); + case RIGHT_ANTI: + case RIGHT_SEMI: + return Stream.empty(); + default: + return getLeft().getRecordType().fields().stream(); + } + } + + private Stream getRightTypes() { + switch (getJoinType()) { + case LEFT: + case OUTER: + return getRight().getRecordType().fields().stream().map(TypeCreator::asNullable); + case LEFT_ANTI: + case LEFT_SEMI: + return Stream.empty(); + default: + return getRight().getRecordType().fields().stream(); + } + } + @Override public O accept( RelVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 8b0409085..e49b4936c 100644 --- a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -42,7 +42,7 @@ public ParseDeserializer( @Override public T deserialize(final JsonParser p, final DeserializationContext ctxt) throws IOException, JsonProcessingException { - var typeString = p.getValueAsString(); + String typeString = p.getValueAsString(); try { String namespace = (String) ctxt.findInjectableValue(SimpleExtension.URI_LOCATOR_KEY, null, null); diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 5e72a551a..9c241542c 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -20,7 +20,7 @@ static NamedStruct of(Iterable names, Type.Struct type) { } default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConverter) { - var type = struct().accept(typeProtoConverter); + io.substrait.proto.Type type = struct().accept(typeProtoConverter); return io.substrait.proto.NamedStruct.newBuilder() .setStruct(type.getStruct()) .addAllNames(names()) @@ -29,7 +29,7 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve static io.substrait.type.NamedStruct fromProto( io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { - var struct = namedStruct.getStruct(); + io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) .struct( diff --git a/core/src/main/java/io/substrait/type/YamlRead.java b/core/src/main/java/io/substrait/type/YamlRead.java index 89f8a9163..c4e7f4866 100644 --- a/core/src/main/java/io/substrait/type/YamlRead.java +++ b/core/src/main/java/io/substrait/type/YamlRead.java @@ -54,7 +54,8 @@ private static Stream parse(String name) { new ObjectMapper(new YAMLFactory()) .enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY) .registerModule(Deserializers.MODULE); - var doc = mapper.readValue(new File(name), SimpleExtension.ExtensionSignatures.class); + SimpleExtension.ExtensionSignatures doc = + mapper.readValue(new File(name), SimpleExtension.ExtensionSignatures.class); logger.atDebug().log( "Parsed {} functions in file {}.", diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index d27bd03e2..83b4ec8a4 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -1,5 +1,6 @@ package io.substrait.type.parser; +import io.substrait.function.ImmutableTypeExpression; import io.substrait.function.ParameterizedType; import io.substrait.function.ParameterizedTypeCreator; import io.substrait.function.TypeExpression; @@ -8,9 +9,11 @@ import io.substrait.type.SubstraitTypeVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.List; import java.util.Locale; import java.util.function.Function; import java.util.function.IntFunction; +import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ErrorNode; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.RuleNode; @@ -19,7 +22,7 @@ public class ParseToPojo { public static Type type(String namespace, SubstraitTypeParser.StartContext ctx) { - var visitor = Visitor.simple(namespace); + Visitor visitor = Visitor.simple(namespace); return (Type) ctx.accept(visitor); } @@ -160,12 +163,12 @@ public Type visitIntervalYear(final SubstraitTypeParser.IntervalYearContext ctx) public TypeExpression visitIntervalDay(final SubstraitTypeParser.IntervalDayContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).intervalDay(p); + if (precision instanceof Integer) { + return withNull(nullable).intervalDay((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).intervalDayE(s); + return withNullP(nullable).intervalDayE((String) precision); } checkExpression(); @@ -177,12 +180,12 @@ public TypeExpression visitIntervalCompound( final SubstraitTypeParser.IntervalCompoundContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).intervalCompound(p); + if (precision instanceof Integer) { + return withNull(nullable).intervalCompound((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).intervalCompoundE(s); + return withNullP(nullable).intervalCompoundE((String) precision); } checkExpression(); @@ -196,7 +199,7 @@ public Type visitUuid(final SubstraitTypeParser.UuidContext ctx) { @Override public Type visitUserDefined(SubstraitTypeParser.UserDefinedContext ctx) { - var name = ctx.Identifier().getSymbol().getText(); + String name = ctx.Identifier().getSymbol().getText(); return withNull(ctx).userDefined(namespace, name); } @@ -280,12 +283,12 @@ public TypeExpression visitPrecisionTimestamp( final SubstraitTypeParser.PrecisionTimestampContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).precisionTimestamp(p); + if (precision instanceof Integer) { + return withNull(nullable).precisionTimestamp((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).precisionTimestampE(s); + return withNullP(nullable).precisionTimestampE((String) precision); } checkExpression(); @@ -297,12 +300,12 @@ public TypeExpression visitPrecisionTimestampTZ( final SubstraitTypeParser.PrecisionTimestampTZContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); - if (precision instanceof Integer p) { - return withNull(nullable).precisionTimestampTZ(p); + if (precision instanceof Integer) { + return withNull(nullable).precisionTimestampTZ((Integer) precision); } - if (precision instanceof String s) { + if (precision instanceof String) { checkParameterizedOrExpression(); - return withNullP(nullable).precisionTimestampTZE(s); + return withNullP(nullable).precisionTimestampTZE((String) precision); } checkExpression(); @@ -325,7 +328,7 @@ private Object i(SubstraitTypeParser.NumericParameterContext ctx) { @Override public TypeExpression visitStruct(final SubstraitTypeParser.StructContext ctx) { boolean nullable = ctx.isnull != null; - var types = + List types = ctx.expr().stream() .map(t -> t.accept(this)) .collect(java.util.stream.Collectors.toList()); @@ -456,17 +459,17 @@ public TypeExpression visitTernary(final SubstraitTypeParser.TernaryContext ctx) public TypeExpression visitMultilineDefinition( final SubstraitTypeParser.MultilineDefinitionContext ctx) { checkExpression(); - var exprs = + List exprs = ctx.expr().stream() .map(t -> t.accept(this)) .collect(java.util.stream.Collectors.toList()); - var identifiers = + List identifiers = ctx.Identifier().stream() .map(t -> t.getText()) .collect(java.util.stream.Collectors.toList()); - var finalExpr = ctx.finalType.accept(this); + TypeExpression finalExpr = ctx.finalType.accept(this); - var bldr = TypeExpression.ReturnProgram.builder(); + ImmutableTypeExpression.ReturnProgram.Builder bldr = TypeExpression.ReturnProgram.builder(); for (int i = 0; i < exprs.size(); i++) { bldr.addAssignments( TypeExpression.ReturnProgram.Assignment.builder() @@ -482,20 +485,7 @@ public TypeExpression visitMultilineDefinition( @Override public TypeExpression visitBinaryExpr(final SubstraitTypeParser.BinaryExprContext ctx) { checkExpression(); - TypeExpression.BinaryOperation.OpType type = - switch (ctx.op.getText().toUpperCase(Locale.ROOT)) { - case "+" -> TypeExpression.BinaryOperation.OpType.ADD; - case "-" -> TypeExpression.BinaryOperation.OpType.SUBTRACT; - case "*" -> TypeExpression.BinaryOperation.OpType.MULTIPLY; - case "/" -> TypeExpression.BinaryOperation.OpType.DIVIDE; - case ">" -> TypeExpression.BinaryOperation.OpType.GT; - case "<" -> TypeExpression.BinaryOperation.OpType.LT; - case "AND" -> TypeExpression.BinaryOperation.OpType.AND; - case "OR" -> TypeExpression.BinaryOperation.OpType.OR; - case "=" -> TypeExpression.BinaryOperation.OpType.EQ; - case ":=" -> TypeExpression.BinaryOperation.OpType.COVERS; - default -> throw new IllegalStateException("Unexpected value: " + ctx.op.getText()); - }; + TypeExpression.BinaryOperation.OpType type = getBinaryExpressionType(ctx.op); return TypeExpression.BinaryOperation.builder() .opType(type) .left(ctx.left.accept(this)) @@ -503,6 +493,33 @@ public TypeExpression visitBinaryExpr(final SubstraitTypeParser.BinaryExprContex .build(); } + private TypeExpression.BinaryOperation.OpType getBinaryExpressionType(Token token) { + switch (token.getText().toUpperCase(Locale.ROOT)) { + case "+": + return TypeExpression.BinaryOperation.OpType.ADD; + case "-": + return TypeExpression.BinaryOperation.OpType.SUBTRACT; + case "*": + return TypeExpression.BinaryOperation.OpType.MULTIPLY; + case "/": + return TypeExpression.BinaryOperation.OpType.DIVIDE; + case ">": + return TypeExpression.BinaryOperation.OpType.GT; + case "<": + return TypeExpression.BinaryOperation.OpType.LT; + case "AND": + return TypeExpression.BinaryOperation.OpType.AND; + case "OR": + return TypeExpression.BinaryOperation.OpType.OR; + case "=": + return TypeExpression.BinaryOperation.OpType.EQ; + case ":=": + return TypeExpression.BinaryOperation.OpType.COVERS; + default: + throw new IllegalStateException("Unexpected value: " + token.getText()); + } + } + @Override public TypeExpression visitNumericLiteral(final SubstraitTypeParser.NumericLiteralContext ctx) { return TypeExpression.IntegerLiteral.builder().value(Integer.parseInt(ctx.getText())).build(); @@ -533,14 +550,7 @@ public TypeExpression visitFunctionCall(final SubstraitTypeParser.FunctionCallCo if (ctx.expr().size() != 2) { throw new IllegalStateException("Only two argument functions exist for type expressions."); } - var name = ctx.Identifier().getSymbol().getText().toUpperCase(Locale.ROOT); - TypeExpression.BinaryOperation.OpType type = - switch (name) { - case "MIN" -> TypeExpression.BinaryOperation.OpType.MIN; - case "MAX" -> TypeExpression.BinaryOperation.OpType.MAX; - default -> throw new IllegalStateException( - "The following operation was unrecognized: " + name); - }; + TypeExpression.BinaryOperation.OpType type = getFunctionType(ctx.Identifier().getSymbol()); return TypeExpression.BinaryOperation.builder() .opType(type) .left(ctx.expr(0).accept(this)) @@ -548,6 +558,18 @@ public TypeExpression visitFunctionCall(final SubstraitTypeParser.FunctionCallCo .build(); } + private TypeExpression.BinaryOperation.OpType getFunctionType(Token token) { + switch (token.getText().toUpperCase(Locale.ROOT)) { + case "MIN": + return TypeExpression.BinaryOperation.OpType.MIN; + case "MAX": + return TypeExpression.BinaryOperation.OpType.MAX; + default: + throw new IllegalStateException( + "The following operation was unrecognized: " + token.getText()); + } + } + @Override public TypeExpression visitNotExpr(final SubstraitTypeParser.NotExprContext ctx) { return TypeExpression.NotOperation.builder().inner(ctx.expr().accept(this)).build(); diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index df45fca8e..b22c68265 100644 --- a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -30,11 +30,11 @@ public static TypeExpression parseExpression(String str, String namespace) { } private static SubstraitTypeParser.StartContext parse(String str) { - var lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); + SubstraitTypeLexer lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); lexer.removeErrorListeners(); lexer.addErrorListener(TypeErrorListener.INSTANCE); - var tokenStream = new CommonTokenStream(lexer); - var parser = new io.substrait.type.SubstraitTypeParser(tokenStream); + CommonTokenStream tokenStream = new CommonTokenStream(lexer); + SubstraitTypeParser parser = new io.substrait.type.SubstraitTypeParser(tokenStream); parser.removeErrorListeners(); parser.addErrorListener(TypeErrorListener.INSTANCE); return parser.start(); diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 95734321e..8a37a2928 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,7 +165,7 @@ public final T visit(final Type.Map expr) { @Override public final T visit(final Type.UserDefined expr) { - var ref = + int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); return typeContainer(expr).userDefined(ref); } diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index baf5c411c..293b25c90 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -264,59 +264,62 @@ public ParameterizedType userDefined(int ref) { @Override protected ParameterizedType wrap(final Object o) { - var bldr = ParameterizedType.newBuilder(); - if (o instanceof Type.Boolean t) { - return bldr.setBool(t).build(); - } else if (o instanceof Type.I8 t) { - return bldr.setI8(t).build(); - } else if (o instanceof Type.I16 t) { - return bldr.setI16(t).build(); - } else if (o instanceof Type.I32 t) { - return bldr.setI32(t).build(); - } else if (o instanceof Type.I64 t) { - return bldr.setI64(t).build(); - } else if (o instanceof Type.FP32 t) { - return bldr.setFp32(t).build(); - } else if (o instanceof Type.FP64 t) { - return bldr.setFp64(t).build(); - } else if (o instanceof Type.String t) { - return bldr.setString(t).build(); - } else if (o instanceof Type.Binary t) { - return bldr.setBinary(t).build(); - } else if (o instanceof Type.Timestamp t) { - return bldr.setTimestamp(t).build(); - } else if (o instanceof Type.Date t) { - return bldr.setDate(t).build(); - } else if (o instanceof Type.Time t) { - return bldr.setTime(t).build(); - } else if (o instanceof Type.TimestampTZ t) { - return bldr.setTimestampTz(t).build(); - } else if (o instanceof Type.IntervalYear t) { - return bldr.setIntervalYear(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedIntervalDay t) { - return bldr.setIntervalDay(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedIntervalCompound t) { - return bldr.setIntervalCompound(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedFixedChar t) { - return bldr.setFixedChar(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedVarChar t) { - return bldr.setVarchar(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedFixedBinary t) { - return bldr.setFixedBinary(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedDecimal t) { - return bldr.setDecimal(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestamp t) { - return bldr.setPrecisionTimestamp(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestampTZ t) { - return bldr.setPrecisionTimestampTz(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedStruct t) { - return bldr.setStruct(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedList t) { - return bldr.setList(t).build(); - } else if (o instanceof ParameterizedType.ParameterizedMap t) { - return bldr.setMap(t).build(); - } else if (o instanceof Type.UUID t) { - return bldr.setUuid(t).build(); + ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); + if (o instanceof Type.Boolean) { + return bldr.setBool((Type.Boolean) o).build(); + } else if (o instanceof Type.I8) { + return bldr.setI8((Type.I8) o).build(); + } else if (o instanceof Type.I16) { + return bldr.setI16((Type.I16) o).build(); + } else if (o instanceof Type.I32) { + return bldr.setI32((Type.I32) o).build(); + } else if (o instanceof Type.I64) { + return bldr.setI64((Type.I64) o).build(); + } else if (o instanceof Type.FP32) { + return bldr.setFp32((Type.FP32) o).build(); + } else if (o instanceof Type.FP64) { + return bldr.setFp64((Type.FP64) o).build(); + } else if (o instanceof Type.String) { + return bldr.setString((Type.String) o).build(); + } else if (o instanceof Type.Binary) { + return bldr.setBinary((Type.Binary) o).build(); + } else if (o instanceof Type.Timestamp) { + return bldr.setTimestamp((Type.Timestamp) o).build(); + } else if (o instanceof Type.Date) { + return bldr.setDate((Type.Date) o).build(); + } else if (o instanceof Type.Time) { + return bldr.setTime((Type.Time) o).build(); + } else if (o instanceof Type.TimestampTZ) { + return bldr.setTimestampTz((Type.TimestampTZ) o).build(); + } else if (o instanceof Type.IntervalYear) { + return bldr.setIntervalYear((Type.IntervalYear) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedIntervalDay) { + return bldr.setIntervalDay((ParameterizedType.ParameterizedIntervalDay) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedIntervalCompound) { + return bldr.setIntervalCompound((ParameterizedType.ParameterizedIntervalCompound) o) + .build(); + } else if (o instanceof ParameterizedType.ParameterizedFixedChar) { + return bldr.setFixedChar((ParameterizedType.ParameterizedFixedChar) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedVarChar) { + return bldr.setVarchar((ParameterizedType.ParameterizedVarChar) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedFixedBinary) { + return bldr.setFixedBinary((ParameterizedType.ParameterizedFixedBinary) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedDecimal) { + return bldr.setDecimal((ParameterizedType.ParameterizedDecimal) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestamp) { + return bldr.setPrecisionTimestamp((ParameterizedType.ParameterizedPrecisionTimestamp) o) + .build(); + } else if (o instanceof ParameterizedType.ParameterizedPrecisionTimestampTZ) { + return bldr.setPrecisionTimestampTz((ParameterizedType.ParameterizedPrecisionTimestampTZ) o) + .build(); + } else if (o instanceof ParameterizedType.ParameterizedStruct) { + return bldr.setStruct((ParameterizedType.ParameterizedStruct) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedList) { + return bldr.setList((ParameterizedType.ParameterizedList) o).build(); + } else if (o instanceof ParameterizedType.ParameterizedMap) { + return bldr.setMap((ParameterizedType.ParameterizedMap) o).build(); + } else if (o instanceof Type.UUID) { + return bldr.setUuid((Type.UUID) o).build(); } throw new UnsupportedOperationException("Unable to wrap type of " + o.getClass()); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index e4e33bacc..661f57fea 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -18,57 +18,87 @@ public ProtoTypeConverter( } public Type from(io.substrait.proto.Type type) { - return switch (type.getKindCase()) { - case BOOL -> n(type.getBool().getNullability()).BOOLEAN; - case I8 -> n(type.getI8().getNullability()).I8; - case I16 -> n(type.getI16().getNullability()).I16; - case I32 -> n(type.getI32().getNullability()).I32; - case I64 -> n(type.getI64().getNullability()).I64; - case FP32 -> n(type.getFp32().getNullability()).FP32; - case FP64 -> n(type.getFp64().getNullability()).FP64; - case STRING -> n(type.getString().getNullability()).STRING; - case BINARY -> n(type.getBinary().getNullability()).BINARY; - case TIMESTAMP -> n(type.getTimestamp().getNullability()).TIMESTAMP; - case DATE -> n(type.getDate().getNullability()).DATE; - case TIME -> n(type.getTime().getNullability()).TIME; - case INTERVAL_YEAR -> n(type.getIntervalYear().getNullability()).INTERVAL_YEAR; - case INTERVAL_DAY -> n(type.getIntervalDay().getNullability()) - // precision defaults to 6 (micros) for backwards compatibility, see protobuf - .intervalDay( - type.getIntervalDay().hasPrecision() ? type.getIntervalDay().getPrecision() : 6); - case INTERVAL_COMPOUND -> n(type.getIntervalCompound().getNullability()) - .intervalCompound(type.getIntervalCompound().getPrecision()); - case TIMESTAMP_TZ -> n(type.getTimestampTz().getNullability()).TIMESTAMP_TZ; - case UUID -> n(type.getUuid().getNullability()).UUID; - case FIXED_CHAR -> n(type.getFixedChar().getNullability()) - .fixedChar(type.getFixedChar().getLength()); - case VARCHAR -> n(type.getVarchar().getNullability()).varChar(type.getVarchar().getLength()); - case FIXED_BINARY -> n(type.getFixedBinary().getNullability()) - .fixedBinary(type.getFixedBinary().getLength()); - case DECIMAL -> n(type.getDecimal().getNullability()) - .decimal(type.getDecimal().getPrecision(), type.getDecimal().getScale()); - case PRECISION_TIME -> n(type.getPrecisionTime().getNullability()) - .precisionTime(type.getPrecisionTime().getPrecision()); - case PRECISION_TIMESTAMP -> n(type.getPrecisionTimestamp().getNullability()) - .precisionTimestamp(type.getPrecisionTimestamp().getPrecision()); - case PRECISION_TIMESTAMP_TZ -> n(type.getPrecisionTimestampTz().getNullability()) - .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); - case STRUCT -> n(type.getStruct().getNullability()) - .struct( - type.getStruct().getTypesList().stream() - .map(this::from) - .collect(java.util.stream.Collectors.toList())); - case LIST -> fromList(type.getList()); - case MAP -> fromMap(type.getMap()); - case USER_DEFINED -> { - var userDefined = type.getUserDefined(); - var t = lookup.getType(userDefined.getTypeReference(), extensions); - yield n(userDefined.getNullability()).userDefined(t.uri(), t.name()); - } - case USER_DEFINED_TYPE_REFERENCE -> throw new UnsupportedOperationException( - "Unsupported user defined reference: " + type); - case KIND_NOT_SET -> throw new UnsupportedOperationException("Type is not set: " + type); - }; + switch (type.getKindCase()) { + case BOOL: + return n(type.getBool().getNullability()).BOOLEAN; + case I8: + return n(type.getI8().getNullability()).I8; + case I16: + return n(type.getI16().getNullability()).I16; + case I32: + return n(type.getI32().getNullability()).I32; + case I64: + return n(type.getI64().getNullability()).I64; + case FP32: + return n(type.getFp32().getNullability()).FP32; + case FP64: + return n(type.getFp64().getNullability()).FP64; + case STRING: + return n(type.getString().getNullability()).STRING; + case BINARY: + return n(type.getBinary().getNullability()).BINARY; + case TIMESTAMP: + return n(type.getTimestamp().getNullability()).TIMESTAMP; + case DATE: + return n(type.getDate().getNullability()).DATE; + case TIME: + return n(type.getTime().getNullability()).TIME; + case INTERVAL_YEAR: + return n(type.getIntervalYear().getNullability()).INTERVAL_YEAR; + case INTERVAL_DAY: + return n(type.getIntervalDay().getNullability()) + // precision defaults to 6 (micros) for backwards compatibility, see protobuf + .intervalDay( + type.getIntervalDay().hasPrecision() ? type.getIntervalDay().getPrecision() : 6); + case INTERVAL_COMPOUND: + return n(type.getIntervalCompound().getNullability()) + .intervalCompound(type.getIntervalCompound().getPrecision()); + case TIMESTAMP_TZ: + return n(type.getTimestampTz().getNullability()).TIMESTAMP_TZ; + case UUID: + return n(type.getUuid().getNullability()).UUID; + case FIXED_CHAR: + return n(type.getFixedChar().getNullability()).fixedChar(type.getFixedChar().getLength()); + case VARCHAR: + return n(type.getVarchar().getNullability()).varChar(type.getVarchar().getLength()); + case FIXED_BINARY: + return n(type.getFixedBinary().getNullability()) + .fixedBinary(type.getFixedBinary().getLength()); + case DECIMAL: + return n(type.getDecimal().getNullability()) + .decimal(type.getDecimal().getPrecision(), type.getDecimal().getScale()); + case PRECISION_TIME: + return n(type.getPrecisionTime().getNullability()) + .precisionTime(type.getPrecisionTime().getPrecision()); + case PRECISION_TIMESTAMP: + return n(type.getPrecisionTimestamp().getNullability()) + .precisionTimestamp(type.getPrecisionTimestamp().getPrecision()); + case PRECISION_TIMESTAMP_TZ: + return n(type.getPrecisionTimestampTz().getNullability()) + .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case STRUCT: + return n(type.getStruct().getNullability()) + .struct( + type.getStruct().getTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList())); + case LIST: + return fromList(type.getList()); + case MAP: + return fromMap(type.getMap()); + case USER_DEFINED: + { + io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); + SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); + return n(userDefined.getNullability()).userDefined(t.uri(), t.name()); + } + case USER_DEFINED_TYPE_REFERENCE: + throw new UnsupportedOperationException("Unsupported user defined reference: " + type); + case KIND_NOT_SET: + throw new UnsupportedOperationException("Type is not set: " + type); + default: + throw new UnsupportedOperationException("Unsupported type: " + type.getKindCase()); + } } public Type.ListType fromList(io.substrait.proto.Type.List list) { diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 1a7dcbc4f..44997bf2a 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -4,7 +4,9 @@ import io.substrait.function.ParameterizedType; import io.substrait.function.TypeExpression; import io.substrait.proto.DerivationExpression; +import io.substrait.proto.DerivationExpression.ReturnProgram.Assignment; import io.substrait.proto.Type; +import java.util.List; public class TypeExpressionProtoVisitor extends BaseProtoConverter { @@ -28,21 +30,7 @@ public BaseProtoTypes typeContainer( @Override public DerivationExpression visit(final TypeExpression.BinaryOperation expr) { - var opType = - switch (expr.opType()) { - case ADD -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_PLUS; - case SUBTRACT -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; - case MIN -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MIN; - case MAX -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MAX; - case LT -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; - // case LTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; - case GT -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_GREATER_THAN; - // case GTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; - // case NOT_EQ -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQ; - case EQ -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQUALS; - case COVERS -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_COVERS; - default -> throw new IllegalStateException("Unexpected value: " + expr.opType()); - }; + DerivationExpression.BinaryOp.BinaryOpType opType = getDerivationOpType(expr.opType()); return DerivationExpression.newBuilder() .setBinaryOp( DerivationExpression.BinaryOp.newBuilder() @@ -53,6 +41,33 @@ public DerivationExpression visit(final TypeExpression.BinaryOperation expr) { .build(); } + private DerivationExpression.BinaryOp.BinaryOpType getDerivationOpType( + TypeExpression.BinaryOperation.OpType type) { + switch (type) { + case ADD: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_PLUS; + case SUBTRACT: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; + case MIN: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MIN; + case MAX: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MAX; + case LT: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; + // case LTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN; + case GT: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_GREATER_THAN; + // case GTE -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS; + // case NOT_EQ -> DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQ; + case EQ: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_EQUALS; + case COVERS: + return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_COVERS; + default: + throw new IllegalStateException("Unexpected value: " + type); + } + } + @Override public DerivationExpression visit(final TypeExpression.NotOperation expr) { return DerivationExpression.newBuilder() @@ -82,7 +97,7 @@ public DerivationExpression visit(final TypeExpression.IntegerLiteral expr) { @Override public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { - var assignments = + List assignments = expr.assignments().stream() .map( a -> @@ -91,7 +106,7 @@ public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { .setExpression(a.expr().accept(this)) .build()) .collect(java.util.stream.Collectors.toList()); - var finalExpr = expr.finalExpression().accept(this); + DerivationExpression finalExpr = expr.finalExpression().accept(this); return DerivationExpression.newBuilder() .setReturnProgram( DerivationExpression.ReturnProgram.newBuilder() @@ -337,59 +352,62 @@ public DerivationExpression userDefined(int ref) { @Override protected DerivationExpression wrap(final Object o) { - var bldr = DerivationExpression.newBuilder(); - if (o instanceof Type.Boolean t) { - return bldr.setBool(t).build(); - } else if (o instanceof Type.I8 t) { - return bldr.setI8(t).build(); - } else if (o instanceof Type.I16 t) { - return bldr.setI16(t).build(); - } else if (o instanceof Type.I32 t) { - return bldr.setI32(t).build(); - } else if (o instanceof Type.I64 t) { - return bldr.setI64(t).build(); - } else if (o instanceof Type.FP32 t) { - return bldr.setFp32(t).build(); - } else if (o instanceof Type.FP64 t) { - return bldr.setFp64(t).build(); - } else if (o instanceof Type.String t) { - return bldr.setString(t).build(); - } else if (o instanceof Type.Binary t) { - return bldr.setBinary(t).build(); - } else if (o instanceof Type.Timestamp t) { - return bldr.setTimestamp(t).build(); - } else if (o instanceof Type.Date t) { - return bldr.setDate(t).build(); - } else if (o instanceof Type.Time t) { - return bldr.setTime(t).build(); - } else if (o instanceof Type.TimestampTZ t) { - return bldr.setTimestampTz(t).build(); - } else if (o instanceof Type.IntervalYear t) { - return bldr.setIntervalYear(t).build(); - } else if (o instanceof DerivationExpression.ExpressionIntervalDay t) { - return bldr.setIntervalDay(t).build(); - } else if (o instanceof DerivationExpression.ExpressionIntervalCompound t) { - return bldr.setIntervalCompound(t).build(); - } else if (o instanceof DerivationExpression.ExpressionFixedChar t) { - return bldr.setFixedChar(t).build(); - } else if (o instanceof DerivationExpression.ExpressionVarChar t) { - return bldr.setVarchar(t).build(); - } else if (o instanceof DerivationExpression.ExpressionFixedBinary t) { - return bldr.setFixedBinary(t).build(); - } else if (o instanceof DerivationExpression.ExpressionDecimal t) { - return bldr.setDecimal(t).build(); - } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestamp t) { - return bldr.setPrecisionTimestamp(t).build(); - } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestampTZ t) { - return bldr.setPrecisionTimestampTz(t).build(); - } else if (o instanceof DerivationExpression.ExpressionStruct t) { - return bldr.setStruct(t).build(); - } else if (o instanceof DerivationExpression.ExpressionList t) { - return bldr.setList(t).build(); - } else if (o instanceof DerivationExpression.ExpressionMap t) { - return bldr.setMap(t).build(); - } else if (o instanceof Type.UUID t) { - return bldr.setUuid(t).build(); + DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); + if (o instanceof Type.Boolean) { + return bldr.setBool((Type.Boolean) o).build(); + } else if (o instanceof Type.I8) { + return bldr.setI8((Type.I8) o).build(); + } else if (o instanceof Type.I16) { + return bldr.setI16((Type.I16) o).build(); + } else if (o instanceof Type.I32) { + return bldr.setI32((Type.I32) o).build(); + } else if (o instanceof Type.I64) { + return bldr.setI64((Type.I64) o).build(); + } else if (o instanceof Type.FP32) { + return bldr.setFp32((Type.FP32) o).build(); + } else if (o instanceof Type.FP64) { + return bldr.setFp64((Type.FP64) o).build(); + } else if (o instanceof Type.String) { + return bldr.setString((Type.String) o).build(); + } else if (o instanceof Type.Binary) { + return bldr.setBinary((Type.Binary) o).build(); + } else if (o instanceof Type.Timestamp) { + return bldr.setTimestamp((Type.Timestamp) o).build(); + } else if (o instanceof Type.Date) { + return bldr.setDate((Type.Date) o).build(); + } else if (o instanceof Type.Time) { + return bldr.setTime((Type.Time) o).build(); + } else if (o instanceof Type.TimestampTZ) { + return bldr.setTimestampTz((Type.TimestampTZ) o).build(); + } else if (o instanceof Type.IntervalYear) { + return bldr.setIntervalYear((Type.IntervalYear) o).build(); + } else if (o instanceof DerivationExpression.ExpressionIntervalDay) { + return bldr.setIntervalDay((DerivationExpression.ExpressionIntervalDay) o).build(); + } else if (o instanceof DerivationExpression.ExpressionIntervalCompound) { + return bldr.setIntervalCompound((DerivationExpression.ExpressionIntervalCompound) o) + .build(); + } else if (o instanceof DerivationExpression.ExpressionFixedChar) { + return bldr.setFixedChar((DerivationExpression.ExpressionFixedChar) o).build(); + } else if (o instanceof DerivationExpression.ExpressionVarChar) { + return bldr.setVarchar((DerivationExpression.ExpressionVarChar) o).build(); + } else if (o instanceof DerivationExpression.ExpressionFixedBinary) { + return bldr.setFixedBinary((DerivationExpression.ExpressionFixedBinary) o).build(); + } else if (o instanceof DerivationExpression.ExpressionDecimal) { + return bldr.setDecimal((DerivationExpression.ExpressionDecimal) o).build(); + } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestamp) { + return bldr.setPrecisionTimestamp((DerivationExpression.ExpressionPrecisionTimestamp) o) + .build(); + } else if (o instanceof DerivationExpression.ExpressionPrecisionTimestampTZ) { + return bldr.setPrecisionTimestampTz((DerivationExpression.ExpressionPrecisionTimestampTZ) o) + .build(); + } else if (o instanceof DerivationExpression.ExpressionStruct) { + return bldr.setStruct((DerivationExpression.ExpressionStruct) o).build(); + } else if (o instanceof DerivationExpression.ExpressionList) { + return bldr.setList((DerivationExpression.ExpressionList) o).build(); + } else if (o instanceof DerivationExpression.ExpressionMap) { + return bldr.setMap((DerivationExpression.ExpressionMap) o).build(); + } else if (o instanceof Type.UUID) { + return bldr.setUuid((Type.UUID) o).build(); } throw new UnsupportedOperationException("Unable to wrap type of " + o.getClass()); } diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 626b77166..5162b1c78 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -124,61 +124,61 @@ public Type userDefined(int ref) { @Override protected Type wrap(final Object o) { - var bldr = Type.newBuilder(); - if (o instanceof Type.Boolean t) { - return bldr.setBool(t).build(); - } else if (o instanceof Type.I8 t) { - return bldr.setI8(t).build(); - } else if (o instanceof Type.I16 t) { - return bldr.setI16(t).build(); - } else if (o instanceof Type.I32 t) { - return bldr.setI32(t).build(); - } else if (o instanceof Type.I64 t) { - return bldr.setI64(t).build(); - } else if (o instanceof Type.FP32 t) { - return bldr.setFp32(t).build(); - } else if (o instanceof Type.FP64 t) { - return bldr.setFp64(t).build(); - } else if (o instanceof Type.String t) { - return bldr.setString(t).build(); - } else if (o instanceof Type.Binary t) { - return bldr.setBinary(t).build(); - } else if (o instanceof Type.Timestamp t) { - return bldr.setTimestamp(t).build(); - } else if (o instanceof Type.Date t) { - return bldr.setDate(t).build(); - } else if (o instanceof Type.Time t) { - return bldr.setTime(t).build(); - } else if (o instanceof Type.TimestampTZ t) { - return bldr.setTimestampTz(t).build(); - } else if (o instanceof Type.IntervalYear t) { - return bldr.setIntervalYear(t).build(); - } else if (o instanceof Type.IntervalDay t) { - return bldr.setIntervalDay(t).build(); - } else if (o instanceof Type.IntervalCompound t) { - return bldr.setIntervalCompound(t).build(); - } else if (o instanceof Type.FixedChar t) { - return bldr.setFixedChar(t).build(); - } else if (o instanceof Type.VarChar t) { - return bldr.setVarchar(t).build(); - } else if (o instanceof Type.FixedBinary t) { - return bldr.setFixedBinary(t).build(); - } else if (o instanceof Type.Decimal t) { - return bldr.setDecimal(t).build(); - } else if (o instanceof Type.PrecisionTimestamp t) { - return bldr.setPrecisionTimestamp(t).build(); - } else if (o instanceof Type.PrecisionTimestampTZ t) { - return bldr.setPrecisionTimestampTz(t).build(); - } else if (o instanceof Type.Struct t) { - return bldr.setStruct(t).build(); - } else if (o instanceof Type.List t) { - return bldr.setList(t).build(); - } else if (o instanceof Type.Map t) { - return bldr.setMap(t).build(); - } else if (o instanceof Type.UUID t) { - return bldr.setUuid(t).build(); - } else if (o instanceof Type.UserDefined t) { - return bldr.setUserDefined(t).build(); + Type.Builder bldr = Type.newBuilder(); + if (o instanceof Type.Boolean) { + return bldr.setBool((Type.Boolean) o).build(); + } else if (o instanceof Type.I8) { + return bldr.setI8((Type.I8) o).build(); + } else if (o instanceof Type.I16) { + return bldr.setI16((Type.I16) o).build(); + } else if (o instanceof Type.I32) { + return bldr.setI32((Type.I32) o).build(); + } else if (o instanceof Type.I64) { + return bldr.setI64((Type.I64) o).build(); + } else if (o instanceof Type.FP32) { + return bldr.setFp32((Type.FP32) o).build(); + } else if (o instanceof Type.FP64) { + return bldr.setFp64((Type.FP64) o).build(); + } else if (o instanceof Type.String) { + return bldr.setString((Type.String) o).build(); + } else if (o instanceof Type.Binary) { + return bldr.setBinary((Type.Binary) o).build(); + } else if (o instanceof Type.Timestamp) { + return bldr.setTimestamp((Type.Timestamp) o).build(); + } else if (o instanceof Type.Date) { + return bldr.setDate((Type.Date) o).build(); + } else if (o instanceof Type.Time) { + return bldr.setTime((Type.Time) o).build(); + } else if (o instanceof Type.TimestampTZ) { + return bldr.setTimestampTz((Type.TimestampTZ) o).build(); + } else if (o instanceof Type.IntervalYear) { + return bldr.setIntervalYear((Type.IntervalYear) o).build(); + } else if (o instanceof Type.IntervalDay) { + return bldr.setIntervalDay((Type.IntervalDay) o).build(); + } else if (o instanceof Type.IntervalCompound) { + return bldr.setIntervalCompound((Type.IntervalCompound) o).build(); + } else if (o instanceof Type.FixedChar) { + return bldr.setFixedChar((Type.FixedChar) o).build(); + } else if (o instanceof Type.VarChar) { + return bldr.setVarchar((Type.VarChar) o).build(); + } else if (o instanceof Type.FixedBinary) { + return bldr.setFixedBinary((Type.FixedBinary) o).build(); + } else if (o instanceof Type.Decimal) { + return bldr.setDecimal((Type.Decimal) o).build(); + } else if (o instanceof Type.PrecisionTimestamp) { + return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); + } else if (o instanceof Type.PrecisionTimestampTZ) { + return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Struct) { + return bldr.setStruct((Type.Struct) o).build(); + } else if (o instanceof Type.List) { + return bldr.setList((Type.List) o).build(); + } else if (o instanceof Type.Map) { + return bldr.setMap((Type.Map) o).build(); + } else if (o instanceof Type.UUID) { + return bldr.setUuid((Type.UUID) o).build(); + } else if (o instanceof Type.UserDefined) { + return bldr.setUserDefined((Type.UserDefined) o).build(); } throw new UnsupportedOperationException("Unable to wrap type of " + o.getClass()); } diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index f19bbef5b..cd2522090 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -73,8 +73,8 @@ void roundtripCustomType() { .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - var protoPlan = planProtoConverter.toProto(plan); - var planReturned = protoPlanConverter.from(protoPlan); + io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } @@ -99,8 +99,8 @@ void roundtripNumberedAnyTypes() { b.fieldReference(input, 0))) .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - var protoPlan = planProtoConverter.toProto(plan); - var planReturned = protoPlanConverter.from(protoPlan); + io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } } diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java index 9a84f137b..eeca3b134 100644 --- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java +++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java @@ -25,8 +25,8 @@ public class ProtoRelConverterTest extends TestBase { @Nested class DefaultAdvancedExtensionTests { - static final StringHolder ENHANCED = new StringHolder("ENHANCED"); - static final StringHolder OPTIMIZED = new StringHolder("OPTIMIZED"); + final StringHolder enhanced = new StringHolder("ENHANCED"); + final StringHolder optimized = new StringHolder("OPTIMIZED"); Rel relWithExtension(AdvancedExtension advancedExtension) { return NamedScan.builder() @@ -38,12 +38,12 @@ Rel relWithExtension(AdvancedExtension advancedExtension) { Rel emptyAdvancedExtension = relWithExtension(AdvancedExtension.builder().build()); Rel advancedExtensionWithOptimization = - relWithExtension(AdvancedExtension.builder().addOptimizations(OPTIMIZED).build()); + relWithExtension(AdvancedExtension.builder().addOptimizations(optimized).build()); Rel advancedExtensionWithEnhancement = - relWithExtension(AdvancedExtension.builder().enhancement(ENHANCED).build()); + relWithExtension(AdvancedExtension.builder().enhancement(enhanced).build()); Rel advancedExtensionWithEnhancementAndOptimization = relWithExtension( - AdvancedExtension.builder().enhancement(ENHANCED).addOptimizations(OPTIMIZED).build()); + AdvancedExtension.builder().enhancement(enhanced).addOptimizations(optimized).build()); @Test void emptyAdvancedExtension() { diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java index abbcb24da..8a8118791 100644 --- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java @@ -22,20 +22,21 @@ public class AggregateRoundtripTest extends TestBase { private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) { - var expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); + Expression.DecimalLiteral expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); Expression.StructLiteral literal = Expression.StructLiteral.builder().addFields(expression).build(); - var input = + io.substrait.relation.ImmutableVirtualTableScan input = VirtualTableScan.builder() .initialSchema(NamedStruct.of(Arrays.asList("decimal"), R.struct(R.decimal(10, 2)))) .addRows(literal) .build(); ExtensionCollector functionCollector = new ExtensionCollector(); - var to = new RelProtoConverter(functionCollector); - var extensions = defaultExtensionCollection; - var from = new ProtoRelConverter(functionCollector, extensions); + RelProtoConverter to = new RelProtoConverter(functionCollector); + io.substrait.extension.SimpleExtension.ExtensionCollection extensions = + defaultExtensionCollection; + ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions); - var measure = + io.substrait.relation.ImmutableMeasure measure = Aggregate.Measure.builder() .function( AggregateFunctionInvocation.builder() @@ -60,8 +61,9 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio .build()) .build(); - var aggRel = Aggregate.builder().input(input).measures(Arrays.asList(measure)).build(); - var protoAggRel = to.toProto(aggRel); + io.substrait.relation.ImmutableAggregate aggRel = + Aggregate.builder().input(input).measures(Arrays.asList(measure)).build(); + io.substrait.proto.Rel protoAggRel = to.toProto(aggRel); assertEquals( protoAggRel.getAggregate().getMeasuresList().get(0).getMeasure().getInvocation(), invocation.toProto()); @@ -70,7 +72,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio @Test void aggregateInvocationRoundtrip() { - for (var invocation : Expression.AggregationInvocation.values()) { + for (Expression.AggregationInvocation invocation : Expression.AggregationInvocation.values()) { assertAggregateRoundtrip(invocation); } } diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java index f82eedc12..9c434be17 100644 --- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java @@ -17,7 +17,7 @@ public class ConsistentPartitionWindowRelRoundtripTest extends TestBase { @Test void consistentPartitionWindowRoundtripSingle() { - var windowFunctionDeclaration = + SimpleExtension.WindowFunctionVariant windowFunctionDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); @@ -70,11 +70,11 @@ void consistentPartitionWindowRoundtripSingle() { @Test void consistentPartitionWindowRoundtripMulti() { - var windowFunctionLeadDeclaration = + SimpleExtension.WindowFunctionVariant windowFunctionLeadDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); - var windowFunctionLagDeclaration = + SimpleExtension.WindowFunctionVariant windowFunctionLagDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 7b2237661..5b1ec5322 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -306,7 +306,7 @@ class ExtensionThroughExpression { @Test void scalarSubquery() { - var rel = + Project rel = b.project( input -> Stream.of( @@ -322,7 +322,7 @@ void scalarSubquery() { @Test void inPredicate() { - var rel = + Project rel = b.project( input -> Stream.of( @@ -337,7 +337,7 @@ void inPredicate() { @Test void setPredicate() { - var rel = + Project rel = b.project( input -> Stream.of( diff --git a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java index 1471154d8..47f53c110 100644 --- a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java @@ -51,8 +51,8 @@ public void roundtripTest(Method m, List paramInst, UnsupportedTypeGener // roundtrip to protobuff and back and check equality Expression val = (Expression) m.invoke(null, paramInst.toArray(new Object[0])); - var to = new ExpressionProtoConverter(null, null); - var from = + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = new ProtoExpressionConverter( null, null, diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java index cfcdaf6fc..7e613f272 100644 --- a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -25,8 +25,9 @@ void ifThenNotNullable() { ExpressionCreator.i64(false, 2)); assertFalse(ifRel.getType().nullable()); - var to = new ExpressionProtoConverter(null, null); - var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = + new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } @@ -39,8 +40,9 @@ void ifThenNullable() { ExpressionCreator.i64(false, 2)); assertTrue(ifRel.getType().nullable()); - var to = new ExpressionProtoConverter(null, null); - var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = + new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index b0f1b5fe3..ccac93bcb 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -15,9 +15,11 @@ public class LiteralRoundtripTest extends TestBase { @Test void decimal() { - var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - var to = new ExpressionProtoConverter(null, null); - var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + io.substrait.expression.Expression.DecimalLiteral val = + ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); + ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + ProtoExpressionConverter from = + new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); } } diff --git a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java index 8489860bb..f19535f65 100644 --- a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java @@ -20,7 +20,7 @@ public class LocalFilesRoundtripTest extends TestBase { private void assertLocalFilesRoundtrip(FileOrFiles file) { - var builder = + io.substrait.relation.ImmutableLocalFiles.Builder builder = LocalFiles.builder() .initialSchema( NamedStruct.builder() @@ -48,8 +48,8 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) { ExpressionCreator.i32(false, 1))) .ifPresent(builder::filter); - var localFiles = builder.build(); - var protoFileRel = relProtoConverter.toProto(localFiles); + io.substrait.relation.ImmutableLocalFiles localFiles = builder.build(); + io.substrait.proto.Rel protoFileRel = relProtoConverter.toProto(localFiles); assertTrue(protoFileRel.getRead().hasFilter()); assertEquals(protoFileRel, relProtoConverter.toProto(protoRelConverter.from(protoFileRel))); } @@ -57,37 +57,53 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) { private ImmutableFileOrFiles.Builder setPath( ImmutableFileOrFiles.Builder builder, ReadRel.LocalFiles.FileOrFiles.PathTypeCase pathTypeCase) { - return switch (pathTypeCase) { - case URI_PATH -> builder.pathType(FileOrFiles.PathType.URI_PATH).path("path"); - case URI_PATH_GLOB -> builder.pathType(FileOrFiles.PathType.URI_PATH_GLOB).path("path"); - case URI_FILE -> builder.pathType(FileOrFiles.PathType.URI_FILE).path("path"); - case URI_FOLDER -> builder.pathType(FileOrFiles.PathType.URI_FOLDER).path("path"); - case PATHTYPE_NOT_SET -> builder; - }; + switch (pathTypeCase) { + case URI_PATH: + return builder.pathType(FileOrFiles.PathType.URI_PATH).path("path"); + case URI_PATH_GLOB: + return builder.pathType(FileOrFiles.PathType.URI_PATH_GLOB).path("path"); + case URI_FILE: + return builder.pathType(FileOrFiles.PathType.URI_FILE).path("path"); + case URI_FOLDER: + return builder.pathType(FileOrFiles.PathType.URI_FOLDER).path("path"); + case PATHTYPE_NOT_SET: + return builder; + default: + throw new IllegalArgumentException("Unknown path type case: " + pathTypeCase); + } } private ImmutableFileOrFiles.Builder setFileFormat( ImmutableFileOrFiles.Builder builder, ReadRel.LocalFiles.FileOrFiles.FileFormatCase fileFormatCase) { - return switch (fileFormatCase) { - case PARQUET -> builder.fileFormat(FileFormat.ParquetReadOptions.builder().build()); - case ARROW -> builder.fileFormat(FileFormat.ArrowReadOptions.builder().build()); - case ORC -> builder.fileFormat(FileFormat.OrcReadOptions.builder().build()); - case DWRF -> builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); - case TEXT -> builder.fileFormat( - FileFormat.DelimiterSeparatedTextReadOptions.builder() - .fieldDelimiter("|") - .maxLineSize(1000) - .quote("\"") - .headerLinesToSkip(1) - .escape("\\") - .build()); - case EXTENSION -> builder.fileFormat( - FileFormat.Extension.builder() - .extension(com.google.protobuf.Any.newBuilder().build()) - .build()); - case FILEFORMAT_NOT_SET -> builder; - }; + switch (fileFormatCase) { + case PARQUET: + return builder.fileFormat(FileFormat.ParquetReadOptions.builder().build()); + case ARROW: + return builder.fileFormat(FileFormat.ArrowReadOptions.builder().build()); + case ORC: + return builder.fileFormat(FileFormat.OrcReadOptions.builder().build()); + case DWRF: + return builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); + case TEXT: + return builder.fileFormat( + FileFormat.DelimiterSeparatedTextReadOptions.builder() + .fieldDelimiter("|") + .maxLineSize(1000) + .quote("\"") + .headerLinesToSkip(1) + .escape("\\") + .build()); + case EXTENSION: + return builder.fileFormat( + FileFormat.Extension.builder() + .extension(com.google.protobuf.Any.newBuilder().build()) + .build()); + case FILEFORMAT_NOT_SET: + return builder; + default: + throw new IllegalArgumentException("Unknown file format case: " + fileFormatCase); + } } @Test diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java index 99a2e2042..923ab1ecc 100644 --- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java @@ -15,11 +15,11 @@ public class ReadRelRoundtripTest extends TestBase { @Test void namedScan() { - var tableName = Stream.of("a_table").collect(Collectors.toList()); - var columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); + List tableName = Stream.of("a_table").collect(Collectors.toList()); + List columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); List columnTypes = Stream.of(R.I64, R.I64).collect(Collectors.toList()); - var namedScan = b.namedScan(tableName, columnNames, columnTypes); + NamedScan namedScan = b.namedScan(tableName, columnNames, columnTypes); namedScan = NamedScan.builder() .from(namedScan) @@ -33,13 +33,13 @@ void namedScan() { @Test void emptyScan() { - var emptyScan = b.emptyScan(); + io.substrait.relation.EmptyScan emptyScan = b.emptyScan(); verifyRoundTrip(emptyScan); } @Test void virtualTable() { - var virtTable = + io.substrait.relation.ImmutableVirtualTableScan virtTable = VirtualTableScan.builder() .initialSchema( NamedStruct.of( diff --git a/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java b/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java index eb990ca42..9bd2e734f 100644 --- a/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java +++ b/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java @@ -54,7 +54,7 @@ public void roundtrip(boolean n) { * @param type */ private void t(Type type) { - var converted = type.accept(typeProtoConverter); + io.substrait.proto.Type converted = type.accept(typeProtoConverter); assertEquals(type, protoTypeConverter.from(converted)); } diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index 96135999f..1dc8f755b 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -102,11 +102,12 @@ public Integer call() throws Exception { } private void printMessage(Message message) throws IOException { - switch (outputFormat) { - case PROTOJSON -> System.out.println( - JsonFormat.printer().includingDefaultValueFields().print(message)); - case PROTOTEXT -> TextFormat.printer().print(message, System.out); - case BINARY -> message.writeTo(System.out); + if (outputFormat == OutputFormat.PROTOJSON) { + System.out.println(JsonFormat.printer().includingDefaultValueFields().print(message)); + } else if (outputFormat == OutputFormat.PROTOTEXT) { + TextFormat.printer().print(message, System.out); + } else if (outputFormat == OutputFormat.BINARY) { + message.writeTo(System.out); } } diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 01103e023..64f9891a4 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -93,7 +93,7 @@ jreleaser { } java { - toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } + toolchain { languageVersion = JavaLanguageVersion.of(11) } withJavadocJar() withSourcesJar() } @@ -149,8 +149,6 @@ dependencies { "calcite-core brings in commons-lang:commons-lang:2.4 which has a security vulnerability" ) } - annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") testImplementation("com.google.protobuf:protobuf-java:${PROTOBUF_VERSION}") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 6a8d84d88..6cba80781 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -38,7 +38,8 @@ public class AggregateFunctions { * conversion was needed, empty otherwise. */ public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { - if (aggFunction instanceof SqlMinMaxAggFunction fun) { + if (aggFunction instanceof SqlMinMaxAggFunction) { + SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction; return Optional.of( fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX); } else if (aggFunction instanceof SqlAvgAggFunction) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java index ba18be900..f33eaa4c8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java @@ -91,9 +91,12 @@ private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { } private static boolean isSimpleFieldReference(FunctionArg e) { - return e instanceof FieldReference fr - && fr.segments().size() == 1 - && fr.segments().get(0) instanceof FieldReference.StructField; + if (!(e instanceof FieldReference)) { + return false; + } + + List segments = ((FieldReference) e).segments(); + return segments.size() == 1 && segments.get(0) instanceof FieldReference.StructField; } private static int getFieldRefOffset(FieldReference fr) { @@ -194,8 +197,8 @@ private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { } private Expression projectOutNonFieldReference(FunctionArg farg) { - if ((farg instanceof Expression e)) { - return projectOutNonFieldReference(e); + if ((farg instanceof Expression)) { + return projectOutNonFieldReference((Expression) farg); } else { throw new IllegalArgumentException("cannot handle non-expression argument for aggregate"); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java index 15591f06e..81c4e9a49 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java @@ -94,38 +94,38 @@ protected RelNodeVisitor() {} * RelVisitor.reverseAccept(RelNode) due to the lack of ability to extend base classes. */ public final OUTPUT reverseAccept(RelNode node) throws EXCEPTION { - if (node instanceof TableScan scan) { - return this.visit(scan); - } else if (node instanceof TableFunctionScan scan) { - return this.visit(scan); - } else if (node instanceof Values values) { - return this.visit(values); - } else if (node instanceof Filter filter) { - return this.visit(filter); - } else if (node instanceof Calc calc) { - return this.visit(calc); - } else if (node instanceof Project project) { - return this.visit(project); - } else if (node instanceof Join join) { - return this.visit(join); - } else if (node instanceof Correlate correlate) { - return this.visit(correlate); - } else if (node instanceof Union union) { - return this.visit(union); - } else if (node instanceof Intersect intersect) { - return this.visit(intersect); - } else if (node instanceof Minus minus) { - return this.visit(minus); - } else if (node instanceof Match match) { - return this.visit(match); - } else if (node instanceof Sort sort) { - return this.visit(sort); - } else if (node instanceof Exchange exchange) { - return this.visit(exchange); - } else if (node instanceof Aggregate aggregate) { - return this.visit(aggregate); - } else if (node instanceof TableModify modify) { - return this.visit(modify); + if (node instanceof TableScan) { + return this.visit((TableScan) node); + } else if (node instanceof TableFunctionScan) { + return this.visit((TableFunctionScan) node); + } else if (node instanceof Values) { + return this.visit((Values) node); + } else if (node instanceof Filter) { + return this.visit((Filter) node); + } else if (node instanceof Calc) { + return this.visit((Calc) node); + } else if (node instanceof Project) { + return this.visit((Project) node); + } else if (node instanceof Join) { + return this.visit((Join) node); + } else if (node instanceof Correlate) { + return this.visit((Correlate) node); + } else if (node instanceof Union) { + return this.visit((Union) node); + } else if (node instanceof Intersect) { + return this.visit((Intersect) node); + } else if (node instanceof Minus) { + return this.visit((Minus) node); + } else if (node instanceof Match) { + return this.visit((Match) node); + } else if (node instanceof Sort) { + return this.visit((Sort) node); + } else if (node instanceof Exchange) { + return this.visit((Exchange) node); + } else if (node instanceof Aggregate) { + return this.visit((Aggregate) node); + } else if (node instanceof TableModify) { + return this.visit((TableModify) node); } else { return this.visitOther(node); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index bb7b99f9c..5a5cac5dc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import com.github.bsideup.jabel.Desugar; import io.substrait.extendedexpression.ExtendedExpression; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extension.SimpleExtension; @@ -45,12 +44,23 @@ public SqlExpressionToSubstrait( this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); } - @Desugar - private record Result( - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) {} + private static final class Result { + final SqlValidator validator; + final CalciteCatalogReader catalogReader; + final Map nameToTypeMap; + final Map nameToNodeMap; + + Result( + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) { + this.validator = validator; + this.catalogReader = catalogReader; + this.nameToTypeMap = nameToTypeMap; + this.nameToNodeMap = nameToNodeMap; + } + } /** * Converts the given SQL expression to an {@link io.substrait.proto.ExtendedExpression } @@ -78,10 +88,10 @@ public io.substrait.proto.ExtendedExpression convert( var result = registerCreateTablesForExtendedExpression(createStatements); return executeInnerSQLExpressions( sqlExpressions, - result.validator(), - result.catalogReader(), - result.nameToTypeMap(), - result.nameToNodeMap()); + result.validator, + result.catalogReader, + result.nameToTypeMap, + result.nameToNodeMap); } private io.substrait.proto.ExtendedExpression executeInnerSQLExpressions( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 0f9bd08a6..804244c22 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList; import io.substrait.expression.Expression; +import io.substrait.expression.Expression.SortDirection; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; @@ -16,6 +17,7 @@ import io.substrait.relation.Fetch; import io.substrait.relation.Filter; import io.substrait.relation.Join; +import io.substrait.relation.Join.JoinType; import io.substrait.relation.LocalFiles; import io.substrait.relation.NamedScan; import io.substrait.relation.Project; @@ -201,28 +203,47 @@ public RelNode visit(Join join, Context context) throws RuntimeException { join.getCondition() .map(c -> c.accept(expressionRexConverter, context)) .orElse(relBuilder.literal(true)); - var joinType = - switch (join.getJoinType()) { - case INNER -> JoinRelType.INNER; - case LEFT -> JoinRelType.LEFT; - case RIGHT -> JoinRelType.RIGHT; - case OUTER -> JoinRelType.FULL; - case SEMI -> JoinRelType.SEMI; - case ANTI -> JoinRelType.ANTI; - case LEFT_SEMI -> JoinRelType.SEMI; - case LEFT_ANTI -> JoinRelType.ANTI; - case UNKNOWN -> throw new UnsupportedOperationException( - "Unknown join type is not supported"); - default -> throw new UnsupportedOperationException( - "Unsupported join type: " + join.getJoinType().name()); - }; + var joinType = asJoinRelType(join); RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build(); return applyRemap(node, join.getRemap()); } + private JoinRelType asJoinRelType(Join join) { + Join.JoinType type = join.getJoinType(); + + if (type == JoinType.INNER) { + return JoinRelType.INNER; + } + if (type == JoinType.LEFT) { + return JoinRelType.LEFT; + } + if (type == JoinType.RIGHT) { + return JoinRelType.RIGHT; + } + if (type == JoinType.OUTER) { + return JoinRelType.FULL; + } + if (type == JoinType.SEMI) { + return JoinRelType.SEMI; + } + if (type == JoinType.ANTI) { + return JoinRelType.ANTI; + } + if (type == JoinType.LEFT_SEMI) { + return JoinRelType.SEMI; + } + if (type == JoinType.LEFT_ANTI) { + return JoinRelType.ANTI; + } + if (type == JoinType.UNKNOWN) { + throw new UnsupportedOperationException("Unknown join type is not supported"); + } + + throw new UnsupportedOperationException("Unsupported join type: " + join.getJoinType().name()); + } + @Override public RelNode visit(Set set, Context context) throws RuntimeException { - int numInputs = set.getInputs().size(); set.getInputs() .forEach( input -> { @@ -232,22 +253,36 @@ public RelNode visit(Set set, Context context) throws RuntimeException { // correspond to the Calcite relations they are associated with. They are retained for now // to enable users to migrate off of them. // See: https://github.com/substrait-io/substrait-java/issues/303 - var builder = - switch (set.getSetOp()) { - case MINUS_PRIMARY -> relBuilder.minus(false, numInputs); - case MINUS_PRIMARY_ALL, MINUS_MULTISET -> relBuilder.minus(true, numInputs); - case INTERSECTION_PRIMARY, INTERSECTION_MULTISET -> relBuilder.intersect( - false, numInputs); - case INTERSECTION_MULTISET_ALL -> relBuilder.intersect(true, numInputs); - case UNION_DISTINCT -> relBuilder.union(false, numInputs); - case UNION_ALL -> relBuilder.union(true, numInputs); - case UNKNOWN -> throw new UnsupportedOperationException( - "Unknown set operation is not supported"); - }; + var builder = getRelBuilder(set); RelNode node = builder.build(); return applyRemap(node, set.getRemap()); } + private RelBuilder getRelBuilder(Set set) { + int numInputs = set.getInputs().size(); + + switch (set.getSetOp()) { + case MINUS_PRIMARY: + return relBuilder.minus(false, numInputs); + case MINUS_PRIMARY_ALL: + case MINUS_MULTISET: + return relBuilder.minus(true, numInputs); + case INTERSECTION_PRIMARY: + case INTERSECTION_MULTISET: + return relBuilder.intersect(false, numInputs); + case INTERSECTION_MULTISET_ALL: + return relBuilder.intersect(true, numInputs); + case UNION_DISTINCT: + return relBuilder.union(false, numInputs); + case UNION_ALL: + return relBuilder.union(true, numInputs); + case UNKNOWN: + throw new UnsupportedOperationException("Unknown set operation is not supported"); + default: + throw new UnsupportedOperationException("Unsupported set operation: " + set.getSetOp()); + } + } + @Override public RelNode visit(Aggregate aggregate, Context context) throws RuntimeException { if (!PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)) { @@ -366,14 +401,24 @@ private RexNode directedRexNode(Expression.SortField sortField, Context context) var expression = sortField.expr(); var rexNode = expression.accept(expressionRexConverter, context); var sortDirection = sortField.direction(); - return switch (sortDirection) { - case ASC_NULLS_FIRST -> relBuilder.nullsFirst(rexNode); - case ASC_NULLS_LAST -> relBuilder.nullsLast(rexNode); - case DESC_NULLS_FIRST -> relBuilder.nullsFirst(relBuilder.desc(rexNode)); - case DESC_NULLS_LAST -> relBuilder.nullsLast(relBuilder.desc(rexNode)); - case CLUSTERED -> throw new RuntimeException( - String.format("Unexpected Expression.SortDirection: Clustered!")); - }; + + if (sortDirection == Expression.SortDirection.ASC_NULLS_FIRST) { + return relBuilder.nullsFirst(rexNode); + } + if (sortDirection == Expression.SortDirection.ASC_NULLS_LAST) { + return relBuilder.nullsLast(rexNode); + } + if (sortDirection == Expression.SortDirection.DESC_NULLS_FIRST) { + return relBuilder.nullsFirst(relBuilder.desc(rexNode)); + } + if (sortDirection == Expression.SortDirection.DESC_NULLS_LAST) { + return relBuilder.nullsLast(relBuilder.desc(rexNode)); + } + if (sortDirection == Expression.SortDirection.CLUSTERED) { + throw new RuntimeException(String.format("Unexpected Expression.SortDirection: Clustered!")); + } + + throw new IllegalArgumentException("Unsupported sort direction: " + sortDirection); } @Override @@ -398,24 +443,30 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Co var sortDirection = sortField.direction(); RexSlot rexSlot = (RexSlot) rex; int fieldIndex = rexSlot.getIndex(); - var fieldDirection = RelFieldCollation.Direction.ASCENDING; - var nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED; - switch (sortDirection) { - case ASC_NULLS_FIRST -> nullDirection = RelFieldCollation.NullDirection.FIRST; - case ASC_NULLS_LAST -> nullDirection = RelFieldCollation.NullDirection.LAST; - case DESC_NULLS_FIRST -> { - nullDirection = RelFieldCollation.NullDirection.FIRST; - fieldDirection = RelFieldCollation.Direction.DESCENDING; - } - case DESC_NULLS_LAST -> { - nullDirection = RelFieldCollation.NullDirection.LAST; - fieldDirection = RelFieldCollation.Direction.DESCENDING; - } - case CLUSTERED -> fieldDirection = RelFieldCollation.Direction.CLUSTERED; - - default -> throw new RuntimeException( + + final RelFieldCollation.Direction fieldDirection; + final RelFieldCollation.NullDirection nullDirection; + + if (sortDirection == SortDirection.ASC_NULLS_FIRST) { + fieldDirection = RelFieldCollation.Direction.ASCENDING; + nullDirection = RelFieldCollation.NullDirection.FIRST; + } else if (sortDirection == SortDirection.ASC_NULLS_LAST) { + fieldDirection = RelFieldCollation.Direction.ASCENDING; + nullDirection = RelFieldCollation.NullDirection.LAST; + } else if (sortDirection == SortDirection.DESC_NULLS_FIRST) { + nullDirection = RelFieldCollation.NullDirection.FIRST; + fieldDirection = RelFieldCollation.Direction.DESCENDING; + } else if (sortDirection == SortDirection.DESC_NULLS_LAST) { + nullDirection = RelFieldCollation.NullDirection.LAST; + fieldDirection = RelFieldCollation.Direction.DESCENDING; + } else if (sortDirection == SortDirection.CLUSTERED) { + fieldDirection = RelFieldCollation.Direction.CLUSTERED; + nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED; + } else { + throw new RuntimeException( String.format("Unexpected Expression.SortDirection enum: %s !", sortDirection)); } + return new RelFieldCollation(fieldIndex, fieldDirection, nullDirection); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index fdb3f8aec..2e31718f9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -32,9 +32,11 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelFieldCollation.Direction; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; @@ -170,17 +172,7 @@ public Rel visit(org.apache.calcite.rel.core.Join join) { var left = apply(join.getLeft()); var right = apply(join.getRight()); var condition = toExpression(join.getCondition()); - var joinType = - switch (join.getJoinType()) { - case INNER -> Join.JoinType.INNER; - case LEFT -> Join.JoinType.LEFT; - case RIGHT -> Join.JoinType.RIGHT; - case FULL -> Join.JoinType.OUTER; - case SEMI -> Join.JoinType.LEFT_SEMI; - case ANTI -> Join.JoinType.LEFT_ANTI; - default -> throw new UnsupportedOperationException( - "Unsupported join type: " + join.getJoinType()); - }; + var joinType = asJoinType(join); // An INNER JOIN with a join condition of TRUE can be encoded as a Substrait Cross relation if (joinType == Join.JoinType.INNER && TRUE.equals(condition)) { @@ -189,6 +181,26 @@ public Rel visit(org.apache.calcite.rel.core.Join join) { return Join.builder().condition(condition).joinType(joinType).left(left).right(right).build(); } + private Join.JoinType asJoinType(org.apache.calcite.rel.core.Join join) { + JoinRelType type = join.getJoinType(); + + if (type == JoinRelType.INNER) { + return Join.JoinType.INNER; + } else if (type == JoinRelType.LEFT) { + return Join.JoinType.LEFT; + } else if (type == JoinRelType.RIGHT) { + return Join.JoinType.RIGHT; + } else if (type == JoinRelType.FULL) { + return Join.JoinType.OUTER; + } else if (type == JoinRelType.SEMI) { + return Join.JoinType.LEFT_SEMI; + } else if (type == JoinRelType.ANTI) { + return Join.JoinType.LEFT_ANTI; + } + + throw new UnsupportedOperationException("Unsupported join type: " + join.getJoinType()); + } + @Override public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { // left input of correlated-join is similar to the left input of a logical join @@ -197,16 +209,22 @@ public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { // right input of correlated-join is similar to a correlated sub-query apply(correlate.getRight()); - var joinType = - switch (correlate.getJoinType()) { - case INNER -> Join.JoinType.INNER; // corresponds to CROSS APPLY join - case LEFT -> Join.JoinType.LEFT; // corresponds to OUTER APPLY join - default -> throw new IllegalArgumentException( - "Invalid correlated join type: " + correlate.getJoinType()); - }; + var joinType = asJoinType(correlate); return super.visit(correlate); } + private Join.JoinType asJoinType(org.apache.calcite.rel.core.Correlate correlate) { + JoinRelType type = correlate.getJoinType(); + + if (type == JoinRelType.INNER) { + return Join.JoinType.INNER; + } else if (type == JoinRelType.LEFT) { + return Join.JoinType.LEFT; + } + + throw new IllegalArgumentException("Invalid correlated join type: " + correlate.getJoinType()); + } + @Override public Rel visit(org.apache.calcite.rel.core.Union union) { var inputs = apply(union.getInputs()); @@ -316,28 +334,17 @@ public Rel visit(org.apache.calcite.rel.core.Sort sort) { private long asLong(RexNode rex) { var expr = toExpression(rex); - if (expr instanceof Expression.I64Literal i) { - return i.value(); - } else if (expr instanceof Expression.I32Literal i) { - return i.value(); + if (expr instanceof Expression.I64Literal) { + return ((Expression.I64Literal) expr).value(); + } else if (expr instanceof Expression.I32Literal) { + return ((Expression.I32Literal) expr).value(); } throw new UnsupportedOperationException("Unknown type: " + rex); } public static Expression.SortField toSortField( RelFieldCollation collation, Type.Struct inputType) { - Expression.SortDirection direction = - switch (collation.direction) { - case STRICTLY_ASCENDING, ASCENDING -> collation.nullDirection - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.ASC_NULLS_LAST - : Expression.SortDirection.ASC_NULLS_FIRST; - case STRICTLY_DESCENDING, DESCENDING -> collation.nullDirection - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.DESC_NULLS_LAST - : Expression.SortDirection.DESC_NULLS_FIRST; - case CLUSTERED -> Expression.SortDirection.CLUSTERED; - }; + Expression.SortDirection direction = asSortDirection(collation); return Expression.SortField.builder() .expr(FieldReference.newRootStructReference(collation.getFieldIndex(), inputType)) @@ -345,6 +352,24 @@ public static Expression.SortField toSortField( .build(); } + private static Expression.SortDirection asSortDirection(RelFieldCollation collation) { + RelFieldCollation.Direction direction = collation.direction; + + if (direction == Direction.STRICTLY_ASCENDING || direction == Direction.ASCENDING) { + return collation.nullDirection == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.ASC_NULLS_LAST + : Expression.SortDirection.ASC_NULLS_FIRST; + } else if (direction == Direction.STRICTLY_DESCENDING || direction == Direction.DESCENDING) { + return collation.nullDirection == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.DESC_NULLS_LAST + : Expression.SortDirection.DESC_NULLS_FIRST; + } else if (direction == Direction.CLUSTERED) { + return Expression.SortDirection.CLUSTERED; + } + + throw new IllegalArgumentException("Unsupported collation direction: " + direction); + } + @Override public Rel visit(org.apache.calcite.rel.core.Exchange exchange) { return super.visit(exchange); diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index c48943ca1..ef1e919ee 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -65,63 +65,90 @@ private Type toSubstrait(RelDataType type, List names) { } TypeCreator creator = Type.withNullability(type.isNullable()); - return switch (type.getSqlTypeName()) { - case BOOLEAN -> creator.BOOLEAN; - case TINYINT -> creator.I8; - case SMALLINT -> creator.I16; - case INTEGER -> creator.I32; - case BIGINT -> creator.I64; - case REAL -> creator.FP32; - case FLOAT, DOUBLE -> creator.FP64; - case DECIMAL -> { - if (type.getPrecision() > 38) { - throw new UnsupportedOperationException( - "unsupported decimal precision " + type.getPrecision()); + + switch (type.getSqlTypeName()) { + case BOOLEAN: + return creator.BOOLEAN; + case TINYINT: + return creator.I8; + case SMALLINT: + return creator.I16; + case INTEGER: + return creator.I32; + case BIGINT: + return creator.I64; + case REAL: + return creator.FP32; + case FLOAT: + case DOUBLE: + return creator.FP64; + case DECIMAL: + { + if (type.getPrecision() > 38) { + throw new UnsupportedOperationException( + "unsupported decimal precision " + type.getPrecision()); + } + return creator.decimal(type.getPrecision(), type.getScale()); } - yield creator.decimal(type.getPrecision(), type.getScale()); - } - case CHAR -> creator.fixedChar(type.getPrecision()); - case VARCHAR -> { - if (type.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { - yield creator.STRING; + case CHAR: + return creator.fixedChar(type.getPrecision()); + case VARCHAR: + { + if (type.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { + return creator.STRING; + } + return creator.varChar(type.getPrecision()); } - yield creator.varChar(type.getPrecision()); - } - case SYMBOL -> creator.STRING; - case DATE -> creator.DATE; - case TIME -> creator.TIME; - case TIMESTAMP -> creator.precisionTimestamp(type.getPrecision()); - case TIMESTAMP_WITH_LOCAL_TIME_ZONE -> creator.precisionTimestampTZ(type.getPrecision()); - case INTERVAL_YEAR, INTERVAL_YEAR_MONTH, INTERVAL_MONTH -> creator.INTERVAL_YEAR; - case INTERVAL_DAY, - INTERVAL_DAY_HOUR, - INTERVAL_DAY_MINUTE, - INTERVAL_DAY_SECOND, - INTERVAL_HOUR, - INTERVAL_HOUR_MINUTE, - INTERVAL_HOUR_SECOND, - INTERVAL_MINUTE, - INTERVAL_MINUTE_SECOND, - INTERVAL_SECOND -> creator.intervalDay(type.getScale()); - case VARBINARY -> creator.BINARY; - case BINARY -> creator.fixedBinary(type.getPrecision()); - case MAP -> { - MapSqlType map = (MapSqlType) type; - yield creator.map( - toSubstrait(map.getKeyType(), names), toSubstrait(map.getValueType(), names)); - } - case ROW -> { - var children = new ArrayList(); - for (var field : type.getFieldList()) { - names.add(field.getName()); - children.add(toSubstrait(field.getType(), names)); + case SYMBOL: + return creator.STRING; + case DATE: + return creator.DATE; + case TIME: + return creator.TIME; + case TIMESTAMP: + return creator.precisionTimestamp(type.getPrecision()); + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + return creator.precisionTimestampTZ(type.getPrecision()); + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + return creator.INTERVAL_YEAR; + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + return creator.intervalDay(type.getScale()); + case VARBINARY: + return creator.BINARY; + case BINARY: + return creator.fixedBinary(type.getPrecision()); + case MAP: + { + MapSqlType map = (MapSqlType) type; + return creator.map( + toSubstrait(map.getKeyType(), names), toSubstrait(map.getValueType(), names)); } - yield creator.struct(children); - } - case ARRAY -> creator.list(toSubstrait(type.getComponentType(), names)); - default -> throw new UnsupportedOperationException( - String.format("Unable to convert the type " + type.toString())); - }; + case ROW: + { + var children = new ArrayList(); + for (var field : type.getFieldList()) { + names.add(field.getName()); + children.add(toSubstrait(field.getType(), names)); + } + return creator.struct(children); + } + case ARRAY: + return creator.list(toSubstrait(type.getComponentType(), names)); + default: + throw new UnsupportedOperationException( + String.format("Unable to convert the type " + type.toString())); + } } public RelDataType toCalcite( @@ -343,14 +370,17 @@ private boolean n(NullableType type) { } private RelDataType t(boolean nullable, SqlTypeName typeName, Integer... props) { - final RelDataType baseType = - switch (props.length) { - case 0 -> typeFactory.createSqlType(typeName); - case 1 -> typeFactory.createSqlType(typeName, props[0]); - case 2 -> typeFactory.createSqlType(typeName, props[0], props[1]); - default -> throw new IllegalArgumentException( - "Unexpected properties length: " + Arrays.toString(props)); - }; + final RelDataType baseType; + if (props.length == 0) { + baseType = typeFactory.createSqlType(typeName); + } else if (props.length == 1) { + baseType = typeFactory.createSqlType(typeName, props[0]); + } else if (props.length == 2) { + baseType = typeFactory.createSqlType(typeName, props[0], props[1]); + } else { + throw new IllegalArgumentException( + "Unexpected properties length: " + Arrays.toString(props)); + } return typeFactory.createTypeWithNullability(baseType, nullable); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 29039be8d..4009df779 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -108,7 +108,7 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) { return signatures.get(lookupFunction); } - static class WrappedAggregateCall implements GenericCall { + static class WrappedAggregateCall implements FunctionConverter.GenericCall { private final AggregateCall call; private final RelNode input; private final RexBuilder rexBuilder; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 36fc68b37..810754e54 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -66,8 +66,11 @@ public class CallConverters { // For now, we only support handling of SqlKind.REINTEPRETET for the case of stored // user-defined literals - if (operand instanceof Expression.FixedBinaryLiteral literal - && type instanceof Type.UserDefined t) { + if (operand instanceof Expression.FixedBinaryLiteral + && type instanceof Type.UserDefined) { + Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; + Type.UserDefined t = (Type.UserDefined) type; + return Expression.UserDefinedLiteral.builder() .uri(t.uri()) .name(t.name()) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index d962dc63a..cec68fcb7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -117,8 +117,8 @@ private static Optional findEnumArg( return Optional.empty(); } Argument arg = args.get(enumAnchor.argIdx); - if (arg instanceof SimpleExtension.EnumArgument ea) { - return Optional.of(ea); + if (arg instanceof SimpleExtension.EnumArgument) { + return Optional.of((SimpleExtension.EnumArgument) arg); } else { return Optional.empty(); } @@ -127,19 +127,21 @@ private static Optional findEnumArg( static Optional fromRex( SimpleExtension.Function function, RexLiteral literal, int argIdx) { - return switch (literal.getType().getSqlTypeName()) { - case SYMBOL -> { - Object v = literal.getValue(); - if (!literal.isNull() && (v instanceof Enum)) { - Enum value = (Enum) v; - ArgAnchor enumAnchor = argAnchor(function, argIdx); - yield findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name())); - } else { - yield Optional.empty(); + switch (literal.getType().getSqlTypeName()) { + case SYMBOL: + { + Object v = literal.getValue(); + if (!literal.isNull() && (v instanceof Enum)) { + Enum value = (Enum) v; + ArgAnchor enumAnchor = argAnchor(function, argIdx); + return findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name())); + } + + return Optional.empty(); } - } - default -> Optional.empty(); - }; + default: + return Optional.empty(); + } } static boolean canConvert(Enum value) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index bd433ccbe..e6d03b663 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -424,15 +424,7 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) expr.sort().stream() .map( sf -> { - Set direction = - switch (sf.direction()) { - case ASC_NULLS_FIRST -> Set.of(SqlKind.NULLS_FIRST); - case ASC_NULLS_LAST -> Set.of(SqlKind.NULLS_LAST); - case DESC_NULLS_FIRST -> Set.of(SqlKind.DESCENDING, SqlKind.NULLS_FIRST); - case DESC_NULLS_LAST -> Set.of(SqlKind.DESCENDING, SqlKind.NULLS_LAST); - case CLUSTERED -> throw new IllegalArgumentException( - "SORT_DIRECTION_CLUSTERED is not supported"); - }; + Set direction = asSqlKind(sf.direction()); return new RexFieldCollation(sf.expr().accept(this, context), direction); }) .collect(ImmutableList.toImmutableList()); @@ -440,19 +432,8 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) RexWindowBound lowerBound = ToRexWindowBound.lowerBound(rexBuilder, expr.lowerBound()); RexWindowBound upperBound = ToRexWindowBound.upperBound(rexBuilder, expr.upperBound()); - boolean rowMode = - switch (expr.boundsType()) { - case ROWS -> true; - case RANGE -> false; - case UNSPECIFIED -> throw new IllegalArgumentException( - "bounds type on window function must be specified"); - }; - - boolean distinct = - switch (expr.invocation()) { - case UNSPECIFIED, ALL -> false; - case DISTINCT -> true; - }; + boolean rowMode = isRowMode(expr); + boolean distinct = isDistinct(expr); // For queries like: SELECT last_value() IGNORE NULLS OVER ... // Substrait has no mechanism to set this, so by default it is false @@ -478,6 +459,53 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) ignoreNulls); } + private Set asSqlKind(Expression.SortDirection direction) { + switch (direction) { + case ASC_NULLS_FIRST: + return Set.of(SqlKind.NULLS_FIRST); + case ASC_NULLS_LAST: + return Set.of(SqlKind.NULLS_LAST); + case DESC_NULLS_FIRST: + return Set.of(SqlKind.DESCENDING, SqlKind.NULLS_FIRST); + case DESC_NULLS_LAST: + return Set.of(SqlKind.DESCENDING, SqlKind.NULLS_LAST); + case CLUSTERED: + throw new IllegalArgumentException("SORT_DIRECTION_CLUSTERED is not supported"); + default: + throw new IllegalArgumentException("Unsupported sort direction: " + direction); + } + } + + private boolean isRowMode(Expression.WindowFunctionInvocation expr) { + Expression.WindowBoundsType boundsType = expr.boundsType(); + + switch (boundsType) { + case ROWS: + return true; + case RANGE: + return false; + case UNSPECIFIED: + throw new IllegalArgumentException("bounds type on window function must be specified"); + default: + throw new IllegalArgumentException( + "Unsupported window function bounds type: " + boundsType); + } + } + + private boolean isDistinct(Expression.WindowFunctionInvocation expr) { + Expression.AggregationInvocation invocation = expr.invocation(); + + switch (invocation) { + case UNSPECIFIED: + case ALL: + return false; + case DISTINCT: + return true; + default: + throw new IllegalArgumentException("Unsupported window function invocation: " + invocation); + } + } + @Override public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException { List needles = @@ -534,12 +562,12 @@ public RexWindowBound visit(WindowBound.Unbounded unbounded) { private String convert(FunctionArg a) { String v; - if (a instanceof EnumArg ea) { - v = ea.value().toString(); - } else if (a instanceof Expression e) { - v = e.getType().accept(new StringTypeVisitor()); - } else if (a instanceof Type t) { - v = t.accept(new StringTypeVisitor()); + if (a instanceof EnumArg) { + v = ((EnumArg) a).value().toString(); + } else if (a instanceof Expression) { + v = ((Expression) a).getType().accept(new StringTypeVisitor()); + } else if (a instanceof Type) { + v = ((Type) a).accept(new StringTypeVisitor()); } else { throw new IllegalStateException("Unexpected value: " + a); } @@ -561,7 +589,8 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti var segment = expr.segments().get(0); RexInputRef rexInputRef; - if (segment instanceof FieldReference.StructField f) { + if (segment instanceof FieldReference.StructField) { + FieldReference.StructField f = (FieldReference.StructField) segment; rexInputRef = new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); } else { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java index ab6233f54..1cb057527 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java @@ -92,14 +92,14 @@ public Optional convert( } private Optional toInt(Expression.Literal l) { - if (l instanceof Expression.I8Literal i8) { - return Optional.of(i8.value()); - } else if (l instanceof Expression.I16Literal i16) { - return Optional.of(i16.value()); - } else if (l instanceof Expression.I32Literal i32) { - return Optional.of(i32.value()); - } else if (l instanceof Expression.I64Literal i64) { - return Optional.of((int) i64.value()); + if (l instanceof Expression.I8Literal) { + return Optional.of(((Expression.I8Literal) l).value()); + } else if (l instanceof Expression.I16Literal) { + return Optional.of(((Expression.I16Literal) l).value()); + } else if (l instanceof Expression.I32Literal) { + return Optional.of(((Expression.I32Literal) l).value()); + } else if (l instanceof Expression.I64Literal) { + return Optional.of((int) ((Expression.I64Literal) l).value()); } logger.atWarn().log("Literal expected to be int type but was not. {}.", l); return Optional.empty(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 399c63b3e..b2ad4c5dc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -203,7 +203,7 @@ && inputTypesMatchDefinedArguments(inputTypes, args)) { * @param args expected arguments as defined in a {@link SimpleExtension.Function} * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise */ - private static boolean inputTypesMatchDefinedArguments( + private boolean inputTypesMatchDefinedArguments( List inputTypes, List args) { Map> wildcardToType = new HashMap<>(); @@ -242,9 +242,8 @@ private static boolean inputTypesMatchDefinedArguments( *

If this exists, the function finder will attempt to find a least-restrictive match using * these. */ - private static - Optional> getSingularInputType(List functions) { - List matchers = new ArrayList<>(); + private Optional> getSingularInputType(List functions) { + List> matchers = new ArrayList<>(); for (var f : functions) { ParameterizedType firstType = null; @@ -274,15 +273,17 @@ Optional> getSingularInputType(List functions) { } } - return switch (matchers.size()) { - case 0 -> Optional.empty(); - case 1 -> Optional.of(matchers.get(0)); - default -> Optional.of(chained(matchers)); - }; + switch (matchers.size()) { + case 0: + return Optional.empty(); + case 1: + return Optional.of(matchers.get(0)); + default: + return Optional.of(chained(matchers)); + } } - public static SingularArgumentMatcher singular( - F function, ParameterizedType type) { + private SingularArgumentMatcher singular(F function, ParameterizedType type) { return (inputType, outputType) -> { var check = isMatch(inputType, type); if (check) { @@ -292,7 +293,7 @@ public static SingularArgumentMatcher si }; } - public static SingularArgumentMatcher chained(List matchers) { + private SingularArgumentMatcher chained(List> matchers) { return (inputType, outputType) -> { for (var s : matchers) { var outcome = s.tryMatch(inputType, outputType); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 47f63908d..b5aee4168 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -89,112 +89,139 @@ public Expression.Literal convert(RexLiteral literal) { return typedNull(type); } - return switch (literal.getType().getSqlTypeName()) { - case TINYINT -> i8(n, i(literal).intValue()); - case SMALLINT -> i16(n, i(literal).intValue()); - case INTEGER -> i32(n, i(literal).intValue()); - case BIGINT -> i64(n, i(literal).longValue()); - case BOOLEAN -> bool(n, literal.getValueAs(Boolean.class)); - case CHAR -> { - var val = literal.getValue(); - if (val instanceof NlsString nls) { - yield fixedChar(n, nls.getValue()); + switch (literal.getType().getSqlTypeName()) { + case TINYINT: + return i8(n, i(literal).intValue()); + case SMALLINT: + return i16(n, i(literal).intValue()); + case INTEGER: + return i32(n, i(literal).intValue()); + case BIGINT: + return i64(n, i(literal).longValue()); + case BOOLEAN: + return bool(n, literal.getValueAs(Boolean.class)); + case CHAR: + { + var val = literal.getValue(); + if (val instanceof NlsString) { + var nls = (NlsString) val; + return fixedChar(n, nls.getValue()); + } + throw new UnsupportedOperationException("Unable to handle char type: " + val); } - throw new UnsupportedOperationException("Unable to handle char type: " + val); - } - case FLOAT, DOUBLE -> fp64(n, literal.getValueAs(Double.class)); - case REAL -> fp32(n, literal.getValueAs(Float.class)); - - case DECIMAL -> { - BigDecimal bd = bd(literal); - yield decimal(n, bd, literal.getType().getPrecision(), literal.getType().getScale()); - } - case VARCHAR -> { - if (literal.getType().getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { - yield string(n, s(literal)); + case FLOAT: + case DOUBLE: + return fp64(n, literal.getValueAs(Double.class)); + case REAL: + return fp32(n, literal.getValueAs(Float.class)); + + case DECIMAL: + { + BigDecimal bd = bd(literal); + return decimal(n, bd, literal.getType().getPrecision(), literal.getType().getScale()); } + case VARCHAR: + { + if (literal.getType().getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { + return string(n, s(literal)); + } - yield varChar(n, s(literal), literal.getType().getPrecision()); - } - case BINARY -> fixedBinary( - n, - ByteString.copyFrom( - padRightIfNeeded( - literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class), - literal.getType().getPrecision()))); - case VARBINARY -> binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class))); - case SYMBOL -> { - Object value = literal.getValue(); - // case TimeUnitRange tur -> string(n, tur.name()); - if (value instanceof NlsString s) { - yield string(n, s.getValue()); - } else if (value instanceof Enum v) { - Optional r = - EnumConverter.canConvert(v) ? Optional.of(string(n, v.name())) : Optional.empty(); - yield r.orElseThrow( - () -> new UnsupportedOperationException("Unable to handle symbol: " + value)); - } else { - throw new UnsupportedOperationException("Unable to handle symbol: " + value); + return varChar(n, s(literal), literal.getType().getPrecision()); } - } - case DATE -> { - DateString date = literal.getValueAs(DateString.class); - LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); - yield ExpressionCreator.date(n, (int) localDate.toEpochDay()); - } - case TIME -> { - TimeString time = literal.getValueAs(TimeString.class); - LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); - yield time(n, NANOSECONDS.toMicros(localTime.toNanoOfDay())); - } - case TIMESTAMP, TIMESTAMP_WITH_LOCAL_TIME_ZONE -> { - TimestampString timestamp = literal.getValueAs(TimestampString.class); - LocalDateTime ldt = - LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER); - yield timestamp(n, ldt); - } - case INTERVAL_YEAR, INTERVAL_YEAR_MONTH, INTERVAL_MONTH -> { - long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); - var years = intervalLength / 12; - var months = intervalLength - years * 12; - yield intervalYear(n, (int) years, (int) months); - } - case INTERVAL_DAY, - INTERVAL_DAY_HOUR, - INTERVAL_DAY_MINUTE, - INTERVAL_DAY_SECOND, - INTERVAL_HOUR, - INTERVAL_HOUR_MINUTE, - INTERVAL_HOUR_SECOND, - INTERVAL_MINUTE, - INTERVAL_MINUTE_SECOND, - INTERVAL_SECOND -> { - // Calcite represents day/time intervals in milliseconds, despite a default scale of 6. - var totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); - var interval = Duration.ofMillis(totalMillis); - - var days = interval.toDays(); - var seconds = interval.minusDays(days).toSeconds(); - var micros = interval.toMillisPart() * 1000; - - yield intervalDay(n, (int) days, (int) seconds, micros, 6); - } - - case ROW -> { - List literals = (List) literal.getValue(); - yield struct(n, literals.stream().map(this::convert).collect(Collectors.toList())); - } - - case ARRAY -> { - List literals = (List) literal.getValue(); - yield list(n, literals.stream().map(this::convert).collect(Collectors.toList())); - } - - default -> throw new UnsupportedOperationException( - String.format( - "Unable to convert the value of %s of type %s to a literal.", - literal, literal.getType().getSqlTypeName())); - }; + case BINARY: + return fixedBinary( + n, + ByteString.copyFrom( + padRightIfNeeded( + literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class), + literal.getType().getPrecision()))); + case VARBINARY: + return binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class))); + case SYMBOL: + { + Object value = literal.getValue(); + // case TimeUnitRange tur -> string(n, tur.name()); + if (value instanceof NlsString) { + return string(n, ((NlsString) value).getValue()); + } else if (value instanceof Enum) { + Enum v = (Enum) value; + Optional r = + EnumConverter.canConvert(v) ? Optional.of(string(n, v.name())) : Optional.empty(); + return r.orElseThrow( + () -> new UnsupportedOperationException("Unable to handle symbol: " + value)); + } else { + throw new UnsupportedOperationException("Unable to handle symbol: " + value); + } + } + case DATE: + { + DateString date = literal.getValueAs(DateString.class); + LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); + return ExpressionCreator.date(n, (int) localDate.toEpochDay()); + } + case TIME: + { + TimeString time = literal.getValueAs(TimeString.class); + LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); + return time(n, NANOSECONDS.toMicros(localTime.toNanoOfDay())); + } + case TIMESTAMP: + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + { + TimestampString timestamp = literal.getValueAs(TimestampString.class); + LocalDateTime ldt = + LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER); + return timestamp(n, ldt); + } + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + { + long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); + var years = intervalLength / 12; + var months = intervalLength - years * 12; + return intervalYear(n, (int) years, (int) months); + } + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + { + // Calcite represents day/time intervals in milliseconds, despite a default scale of 6. + var totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); + var interval = Duration.ofMillis(totalMillis); + + var days = interval.toDays(); + var seconds = interval.minusDays(days).toSeconds(); + var micros = interval.toMillisPart() * 1000; + + return intervalDay(n, (int) days, (int) seconds, micros, 6); + } + + case ROW: + { + List literals = (List) literal.getValue(); + return struct(n, literals.stream().map(this::convert).collect(Collectors.toList())); + } + + case ARRAY: + { + List literals = (List) literal.getValue(); + return list(n, literals.stream().map(this::convert).collect(Collectors.toList())); + } + + default: + throw new UnsupportedOperationException( + String.format( + "Unable to convert the value of %s of type %s to a literal.", + literal, literal.getType().getSqlTypeName())); + } } public static byte[] padRightIfNeeded( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 2bc7ec534..a02e7fbac 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -129,25 +129,30 @@ public Expression visitRangeRef(RexRangeRef rangeRef) { public Expression visitFieldAccess(RexFieldAccess fieldAccess) { SqlKind kind = fieldAccess.getReferenceExpr().getKind(); switch (kind) { - case CORREL_VARIABLE -> { - int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess); - - return FieldReference.newRootStructOuterReference( - fieldAccess.getField().getIndex(), - typeConverter.toSubstrait(fieldAccess.getType()), - stepsOut); - } - case ITEM, INPUT_REF, FIELD_ACCESS -> { - Expression expression = fieldAccess.getReferenceExpr().accept(this); - if (expression instanceof FieldReference) { - FieldReference nestedReference = (FieldReference) expression; - return nestedReference.dereferenceStruct(fieldAccess.getField().getIndex()); - } else { - return FieldReference.newStructReference(fieldAccess.getField().getIndex(), expression); + case CORREL_VARIABLE: + { + int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess); + + return FieldReference.newRootStructOuterReference( + fieldAccess.getField().getIndex(), + typeConverter.toSubstrait(fieldAccess.getType()), + stepsOut); } - } - default -> throw new UnsupportedOperationException( - String.format("RexFieldAccess for SqlKind %s not supported", kind)); + case ITEM: + case INPUT_REF: + case FIELD_ACCESS: + { + Expression expression = fieldAccess.getReferenceExpr().accept(this); + if (expression instanceof FieldReference) { + FieldReference nestedReference = (FieldReference) expression; + return nestedReference.dereferenceStruct(fieldAccess.getField().getIndex()); + } else { + return FieldReference.newStructReference(fieldAccess.getField().getIndex(), expression); + } + } + default: + throw new UnsupportedOperationException( + String.format("RexFieldAccess for SqlKind %s not supported", kind)); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java index c80fd2995..53879fdc7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java @@ -135,7 +135,7 @@ private Optional> getMappedExpressionArguments( .orElse(Optional.empty()); } - protected static class WrappedScalarCall implements GenericCall { + protected static class WrappedScalarCall implements FunctionConverter.GenericCall { private final RexCall delegate; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java index 9cfbf9db5..460c825b7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelFieldCollation.Direction; import org.apache.calcite.rex.RexFieldCollation; public class SortFieldConverter { @@ -10,23 +11,27 @@ public class SortFieldConverter { public static Expression.SortField toSortField( RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) { var expr = rexFieldCollation.left.accept(rexExpressionConverter); - var rexDirection = rexFieldCollation.getDirection(); - Expression.SortDirection direction = - switch (rexDirection) { - case ASCENDING -> rexFieldCollation.getNullDirection() - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.ASC_NULLS_LAST - : Expression.SortDirection.ASC_NULLS_FIRST; - case DESCENDING -> rexFieldCollation.getNullDirection() - == RelFieldCollation.NullDirection.LAST - ? Expression.SortDirection.DESC_NULLS_LAST - : Expression.SortDirection.DESC_NULLS_FIRST; - default -> throw new IllegalArgumentException( - String.format( - "Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!", - rexDirection)); - }; + Expression.SortDirection direction = asSortDirection(rexFieldCollation); return Expression.SortField.builder().expr(expr).direction(direction).build(); } + + private static Expression.SortDirection asSortDirection(RexFieldCollation collation) { + RelFieldCollation.Direction direction = collation.getDirection(); + + if (direction == Direction.ASCENDING) { + return collation.getNullDirection() == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.ASC_NULLS_LAST + : Expression.SortDirection.ASC_NULLS_FIRST; + } + if (direction == Direction.DESCENDING) { + return collation.getNullDirection() == RelFieldCollation.NullDirection.LAST + ? Expression.SortDirection.DESC_NULLS_LAST + : Expression.SortDirection.DESC_NULLS_FIRST; + } + + throw new IllegalArgumentException( + String.format( + "Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!", direction)); + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java index 3208905ca..ad773136e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java @@ -16,18 +16,25 @@ public static WindowBound toWindowBound(RexWindowBound rexWindowBound) { if (rexWindowBound.isUnbounded()) { return WindowBound.UNBOUNDED; } else { - if (rexWindowBound.getOffset() instanceof RexLiteral literal - && SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { - BigDecimal offset = (BigDecimal) literal.getValue4(); - if (rexWindowBound.isPreceding()) { - return WindowBound.Preceding.of(offset.longValue()); - } - if (rexWindowBound.isFollowing()) { - return WindowBound.Following.of(offset.longValue()); + var node = rexWindowBound.getOffset(); + + if (node instanceof RexLiteral) { + var literal = (RexLiteral) node; + if (SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { + BigDecimal offset = (BigDecimal) literal.getValue4(); + + if (rexWindowBound.isPreceding()) { + return WindowBound.Preceding.of(offset.longValue()); + } + if (rexWindowBound.isFollowing()) { + return WindowBound.Following.of(offset.longValue()); + } + + throw new IllegalStateException( + "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING"); } - throw new IllegalStateException( - "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING"); } + throw new IllegalArgumentException( String.format( "substrait only supports integer window offsets. Received: %s", diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java index a1e9bff33..1d988495c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java @@ -99,7 +99,7 @@ public Optional convert( return m.attemptMatch(wrapped, topLevelConverter); } - static class WrappedWindowRelCall implements GenericCall { + static class WrappedWindowRelCall implements FunctionConverter.GenericCall { private final Window.RexWinAggCall winAggCall; private final RexWindowBound lowerBound; private final RexWindowBound upperBound; diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java index 83967195a..3751c0880 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitCreateStatementParser.java @@ -67,10 +67,12 @@ public static List processCreateStatements(String createStatemen SqlNodeList sqlNode = parser.parseStmtList(); for (SqlNode parsed : sqlNode) { - if (!(parsed instanceof SqlCreateTable create)) { + if (!(parsed instanceof SqlCreateTable)) { throw fail("Not a valid CREATE TABLE statement."); } + SqlCreateTable create = (SqlCreateTable) parsed; + if (create.name.names.size() > 1) { throw fail("Only simple table names are allowed.", create.name.getParserPosition()); } @@ -83,7 +85,7 @@ public static List processCreateStatements(String createStatemen List columnTypes = new ArrayList<>(); for (SqlNode node : create.columnList) { - if (!(node instanceof SqlColumnDeclaration col)) { + if (!(node instanceof SqlColumnDeclaration)) { if (node instanceof SqlKeyConstraint) { // key constraints declarations, like primary key declaration, are valid and should not // result in parse exceptions. Ignore the constraint declaration. @@ -93,6 +95,8 @@ public static List processCreateStatements(String createStatemen throw fail("Unexpected column list construction.", node.getParserPosition()); } + SqlColumnDeclaration col = (SqlColumnDeclaration) node; + if (col.name.names.size() != 1) { throw fail("Expected simple column names.", col.name.getParserPosition()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index 62357aa3d..32d4b2553 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -41,15 +41,20 @@ public class AggregationFunctionsTest extends PlanTestBase { // Create the given function call on the given field of the input private Aggregate.Measure functionPicker(Rel input, int field, String fname) { - return switch (fname) { - case "min" -> b.min(input, field); - case "max" -> b.max(input, field); - case "sum" -> b.sum(input, field); - case "sum0" -> b.sum0(input, field); - case "avg" -> b.avg(input, field); - default -> throw new RuntimeException( - String.format("no function is associated with %s", fname)); - }; + switch (fname) { + case "min": + return b.min(input, field); + case "max": + return b.max(input, field); + case "sum": + return b.sum(input, field); + case "sum0": + return b.sum0(input, field); + case "avg": + return b.avg(input, field); + default: + throw new RuntimeException(String.format("no function is associated with %s", fname)); + } } // Create one function call per numeric type column diff --git a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java index e8036a248..5c3f963bf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java @@ -37,10 +37,9 @@ private static Map buildOuterFieldRefMap(RelRoot root) public void lateralJoinQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales CROSS JOIN LATERAL - (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"""; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales CROSS JOIN LATERAL\n" + + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; /* the calcite plan for the above query is: LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3]) @@ -69,11 +68,9 @@ public void lateralJoinQuery() throws SqlParseException { public void outerApplyQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales OUTER APPLY - (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk) - """; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales OUTER APPLY\n" + + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; RelRoot root = getCalcitePlan(sql); @@ -92,15 +89,14 @@ public void outerApplyQuery() throws SqlParseException { public void nestedApplyJoinQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales CROSS APPLY - ( SELECT i_item_sk - FROM item CROSS APPLY - ( SELECT p_promo_sk - FROM promotion - WHERE p_item_sk = i_item_sk AND p_item_sk = ss_item_sk ) - WHERE item.i_item_sk = store_sales.ss_item_sk )"""; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales CROSS APPLY\n" + + " ( SELECT i_item_sk\n" + + " FROM item CROSS APPLY\n" + + " ( SELECT p_promo_sk\n" + + " FROM promotion\n" + + " WHERE p_item_sk = i_item_sk AND p_item_sk = ss_item_sk )\n" + + " WHERE item.i_item_sk = store_sales.ss_item_sk )"; /* the calcite plan for the above query is: LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3]) @@ -135,10 +131,9 @@ public void nestedApplyJoinQuery() throws SqlParseException { public void crossApplyQuery() throws SqlParseException { String sql; sql = - """ - SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk - FROM store_sales CROSS APPLY - (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"""; + "SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk\n" + + "FROM store_sales CROSS APPLY\n" + + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; FeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); SqlToSubstrait s = new SqlToSubstrait(featureBoard); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java index fc3c4559b..775e1bfe5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java @@ -14,12 +14,16 @@ public abstract class CalciteObjs { final RexBuilder rex = new RexBuilder(type); RelDataType t(SqlTypeName typeName, int... vals) { - return switch (vals.length) { - case 0 -> type.createSqlType(typeName); - case 1 -> type.createSqlType(typeName, vals[0]); - case 2 -> type.createSqlType(typeName, vals[0], vals[1]); - default -> throw new IllegalArgumentException(); - }; + switch (vals.length) { + case 0: + return type.createSqlType(typeName); + case 1: + return type.createSqlType(typeName, vals[0]); + case 2: + return type.createSqlType(typeName, vals[0], vals[1]); + default: + throw new IllegalArgumentException(); + } } RelDataType tN(SqlTypeName typeName, int... vals) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index 4972fce50..3f7407530 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -65,11 +65,9 @@ void handleInputReferenceSort() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - Collation: [0] - LogicalSort(sort0=[$0], dir0=[ASC]) - LogicalTableScan(table=[[example]]) - """; + "Collation: [0]\n" + + "LogicalSort(sort0=[$0], dir0=[ASC])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -95,13 +93,11 @@ void handleCastExpressionSort() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - LogicalProject(a0=[$0]) - Collation: [1] - LogicalSort(sort0=[$1], dir0=[ASC]) - LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[$0])\n" + + " Collation: [1]\n" + + " LogicalSort(sort0=[$1], dir0=[ASC])\n" + + " LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -127,13 +123,11 @@ void handleCastProjectAndSortWithSortDirection() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - LogicalProject(a0=[CAST($0):INTEGER NOT NULL]) - Collation: [1 DESC-nulls-last] - LogicalSort(sort0=[$1], dir0=[DESC-nulls-last]) - LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[CAST($0):INTEGER NOT NULL])\n" + + " Collation: [1 DESC-nulls-last]\n" + + " LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])\n" + + " LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -159,13 +153,11 @@ void handleCastSortToOriginalType() { b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = - """ - LogicalProject(a0=[$0]) - Collation: [1 DESC-nulls-last] - LogicalSort(sort0=[$1], dir0=[DESC-nulls-last]) - LogicalProject(a=[$0], a0=[$0]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[$0])\n" + + " Collation: [1 DESC-nulls-last]\n" + + " LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])\n" + + " LogicalProject(a=[$0], a0=[$0])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); @@ -194,13 +186,11 @@ void handleComplex2ExpressionSort() { b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32)))); String expected = - """ - LogicalProject(a0=[$0], b0=[$1]) - Collation: [2 DESC, 3] - LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC]) - LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)]) - LogicalTableScan(table=[[example]]) - """; + "LogicalProject(a0=[$0], b0=[$1])\n" + + " Collation: [2 DESC, 3]\n" + + " LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC])\n" + + " LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)])\n" + + " LogicalTableScan(table=[[example]])\n"; RelNode relReturned = substraitToCalcite.convert(rel); var sw = new StringWriter(); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index bb4dcd746..e4e1c4920 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -21,9 +21,7 @@ void preserveNamesFromSql() throws Exception { SqlToSubstrait s = new SqlToSubstrait(); var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - String query = """ - SELECT "a", "B" FROM foo GROUP BY a, b - """; + String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; List expectedNames = List.of("a", "B"); List calciteRelRoots = s.sqlToRelNode(query, catalogReader); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java index 0b437d789..ea7603a1f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java @@ -85,24 +85,17 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a" - FROM - "nested"."my_table"; - """; + "SELECT\n" + " \"nested\".\"my_table\".\"a\"\n" + "FROM\n" + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -121,29 +114,25 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"."b" - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\".\"b\"\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - struct_field { - field: 0 # b - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # b\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -162,34 +151,30 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"."b"."c" - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\".\"b\".\"c\"\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - struct_field { - field: 0 # b - child: { - struct_field { - field: 0 # c - } - } - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # b\n" + + " child: {\n" + + " struct_field {\n" + + " field: 0 # c\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -207,29 +192,25 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"[1] - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\"[1]\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - list_element { - offset: 1 - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " list_element {\n" + + " offset: 1\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -251,39 +232,35 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table"."a"[1][2][3] - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".\"a\"[1][2][3]\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 1 # a - child { - list_element { - offset: 1 - child { - list_element { - offset: 2 - child { - list_element { - offset: 3 - } - } - } - } - } - } - } - } - root_reference: {} - } - """; + "selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 1 # a\n" + + " child {\n" + + " list_element {\n" + + " offset: 1\n" + + " child {\n" + + " list_element {\n" + + " offset: 2\n" + + " child {\n" + + " list_element {\n" + + " offset: 3\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference: {}\n" + + "}"; test(table, query, expectedExpressionText); } @@ -308,51 +285,47 @@ public RelDataType getRowType(RelDataTypeFactory factory) { }; String query = - """ - SELECT - "nested"."my_table".a.b[2].c['my_map_key'].x - FROM - "nested"."my_table"; - """; + "SELECT\n" + + " \"nested\".\"my_table\".a.b[2].c['my_map_key'].x\n" + + "FROM\n" + + " \"nested\".\"my_table\";"; String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 0 # .a - child { - struct_field { - field: 0 # .b - child { - list_element { - offset: 2 - child { - struct_field { - field: 0 # .c - child { - map_key { - map_key { - string: "my_map_key" # ['my_map_key'] - } - child { - struct_field { - field: 0 # .x - } - } - } - } - } - } - } - } - } - } - } - } - root_reference {} - } - """; + " selection {\n" + + " direct_reference {\n" + + " struct_field {\n" + + " field: 0 # .a\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # .b\n" + + " child {\n" + + " list_element {\n" + + " offset: 2\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # .c\n" + + " child {\n" + + " map_key {\n" + + " map_key {\n" + + " string: \"my_map_key\" # ['my_map_key']\n" + + " }\n" + + " child {\n" + + " struct_field {\n" + + " field: 0 # .x\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " root_reference {}\n" + + "}\n"; test(table, query, expectedExpressionText); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index 48d13612f..0dca1f7d1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -76,34 +76,28 @@ public Optional visit(Cross cross, EmptyVisitationContext context) }; var featureBoard = ImmutableFeatureBoard.builder().build(); - Plan plan1 = - assertProtoPlanRoundrip( - """ - select - c.c_custKey, - o.o_custkey - from - "customer" c cross join - "orders" o - """, - new SqlToSubstrait(featureBoard)); + String query1 = + "select\n" + + " c.c_custKey,\n" + + " o.o_custkey\n" + + "from\n" + + " \"customer\" c cross join\n" + + " \"orders\" o"; + Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait(featureBoard)); plan1 .getRoots() .forEach( t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE)); assertEquals(1, counter[0]); - Plan plan2 = - assertProtoPlanRoundrip( - """ - select - c.c_custKey, - o.o_custkey - from - "customer" c, - "orders" o - """, - new SqlToSubstrait(featureBoard)); + String query2 = + "select\n" + + " c.c_custKey,\n" + + " o.o_custkey\n" + + "from\n" + + " \"customer\" c,\n" + + " \"orders\" o"; + Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait(featureBoard)); plan2 .getRoots() .forEach( diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java b/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java index deac90e87..3e57905e2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java @@ -18,17 +18,7 @@ private SetUtils() {} * @return a sql query */ public static String getSetQuery(Set.SetOp op, boolean multi) { - String opString = - switch (op) { - case MINUS_PRIMARY -> "EXCEPT"; - case MINUS_PRIMARY_ALL -> "EXCEPT ALL"; - case INTERSECTION_MULTISET -> "INTERSECT"; - case INTERSECTION_MULTISET_ALL -> "INTERSECT ALL"; - case UNION_DISTINCT -> "UNION"; - case UNION_ALL -> "UNION ALL"; - default -> throw new UnsupportedOperationException( - "Unknown set operation is not supported"); - }; + String opString = asString(op); StringBuilder query = new StringBuilder(); query.append( @@ -49,6 +39,25 @@ public static String getSetQuery(Set.SetOp op, boolean multi) { } } + private static String asString(Set.SetOp op) { + switch (op) { + case MINUS_PRIMARY: + return "EXCEPT"; + case MINUS_PRIMARY_ALL: + return "EXCEPT ALL"; + case INTERSECTION_MULTISET: + return "INTERSECT"; + case INTERSECTION_MULTISET_ALL: + return "INTERSECT ALL"; + case UNION_DISTINCT: + return "UNION"; + case UNION_ALL: + return "UNION ALL"; + default: + throw new UnsupportedOperationException("Unknown set operation is not supported"); + } + } + // Generate all SetOp types excluding: // * MINUS_MULTISET, INTERSECTION_PRIMARY: do not map to Calcite relations // * UNKNOWN: invalid diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts index 0acec7036..6af7a3397 100644 --- a/spark/build.gradle.kts +++ b/spark/build.gradle.kts @@ -96,15 +96,12 @@ configurations.all { } java { - toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } + toolchain { languageVersion = JavaLanguageVersion.of(17) } withJavadocJar() withSourcesJar() } -tasks.withType() { - targetCompatibility = "" - scalaCompileOptions.additionalParameters = listOf("-release:17") -} +tasks.withType() { scalaCompileOptions.additionalParameters = listOf("-release:17") } var SLF4J_VERSION = properties.get("slf4j.version") var SPARKBUNDLE_VERSION = properties.get("sparkbundle.version") @@ -150,5 +147,6 @@ tasks { test { dependsOn(":core:shadowJar") useJUnitPlatform { includeEngines("scalatest") } + jvmArgs("--add-exports=java.base/sun.nio.ch=ALL-UNNAMED") } }