Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ dependencies {
implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}")
annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}")
compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}")
annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2")
compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2")
}

val submodulesUpdate by
Expand All @@ -41,20 +39,10 @@ allprojects {
repositories { mavenCentral() }

tasks.configureEach<Test> {
val javaToolchains = project.extensions.getByType<JavaToolchainService>()
useJUnitPlatform()
javaLauncher.set(javaToolchains.launcherFor { languageVersion.set(JavaLanguageVersion.of(11)) })
testLogging { exceptionFormat = TestExceptionFormat.FULL }
}
tasks.withType<JavaCompile> {
sourceCompatibility = "17"
if (project.name != "core") {
options.release.set(11)
} else {
options.release.set(8)
}
dependsOn(submodulesUpdate)
}
tasks.withType<JavaCompile> { dependsOn(submodulesUpdate) }

group = "io.substrait"
version = "${version}"
Expand Down
12 changes: 5 additions & 7 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ dependencies {
implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}")
annotationProcessor("org.immutables:value:${IMMUTABLES_VERSION}")
compileOnly("org.immutables:value-annotations:${IMMUTABLES_VERSION}")
annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2")
compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2")
}

configurations[JavaPlugin.API_CONFIGURATION_NAME].let { apiConfiguration ->
Expand Down Expand Up @@ -233,12 +231,12 @@ tasks {
jar { manifest { from("build/generated/sources/manifest/META-INF/MANIFEST.MF") } }
}

// Set the release instead of using a Java 8 toolchain since ANTLR requires Java 11+ to run
tasks.withType<JavaCompile>().configureEach { options.release = 8 }

java {
toolchain {
languageVersion.set(JavaLanguageVersion.of(17))
withJavadocJar()
withSourcesJar()
}
withJavadocJar()
withSourcesJar()
}

configurations { runtimeClasspath { resolutionStrategy.activateDependencyLocking() } }
Expand Down
68 changes: 41 additions & 27 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.substrait.dsl;

import com.github.bsideup.jabel.Desugar;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.Cast;
Expand Down Expand Up @@ -87,8 +86,8 @@ private Aggregate aggregate(
Function<Rel, List<Aggregate.Measure>> measuresFn,
Optional<Rel.Remap> remap,
Rel input) {
var groupings = groupingsFn.apply(input);
var measures = measuresFn.apply(input);
List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
List<Aggregate.Measure> measures = measuresFn.apply(input);
return Aggregate.builder()
.groupings(groupings)
.measures(measures)
Expand Down Expand Up @@ -147,12 +146,27 @@ public Filter filter(Function<Rel, Expression> conditionFn, Rel.Remap remap, Rel

private Filter filter(
Function<Rel, Expression> conditionFn, Optional<Rel.Remap> remap, Rel input) {
var condition = conditionFn.apply(input);
Expression condition = conditionFn.apply(input);
return Filter.builder().input(input).condition(condition).remap(remap).build();
}

@Desugar
public record JoinInput(Rel left, Rel right) {}
public static final class JoinInput {
private final Rel left;
private final Rel right;

JoinInput(Rel left, Rel right) {
this.left = left;
this.right = right;
}

public Rel left() {
return left;
}

public Rel right() {
return right;
}
}

public Join innerJoin(Function<JoinInput, Expression> conditionFn, Rel left, Rel right) {
return join(conditionFn, Join.JoinType.INNER, left, right);
Expand Down Expand Up @@ -183,7 +197,7 @@ private Join join(
Optional<Rel.Remap> remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
Expression condition = conditionFn.apply(new JoinInput(left, right));
return Join.builder()
.left(left)
.right(right)
Expand Down Expand Up @@ -263,7 +277,7 @@ private NestedLoopJoin nestedLoopJoin(
Optional<Rel.Remap> remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
Expression condition = conditionFn.apply(new JoinInput(left, right));
return NestedLoopJoin.builder()
.left(left)
.right(right)
Expand Down Expand Up @@ -291,8 +305,8 @@ private NamedScan namedScan(
Iterable<String> columnNames,
Iterable<Type> types,
Optional<Rel.Remap> remap) {
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
var namedStruct = NamedStruct.of(columnNames, struct);
Type.Struct struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
NamedStruct namedStruct = NamedStruct.of(columnNames, struct);
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
}

Expand All @@ -315,7 +329,7 @@ private Project project(
Function<Rel, Iterable<? extends Expression>> expressionsFn,
Optional<Rel.Remap> remap,
Rel input) {
var expressions = expressionsFn.apply(input);
Iterable<? extends Expression> expressions = expressionsFn.apply(input);
return Project.builder().input(input).expressions(expressions).remap(remap).build();
}

Expand All @@ -332,7 +346,7 @@ private Expand expand(
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
Optional<Rel.Remap> remap,
Rel input) {
var fields = fieldsFn.apply(input);
Iterable<? extends Expand.ExpandField> fields = fieldsFn.apply(input);
return Expand.builder().input(input).fields(fields).remap(remap).build();
}

Expand Down Expand Up @@ -363,7 +377,7 @@ private Sort sort(
Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn,
Optional<Rel.Remap> remap,
Rel input) {
var condition = sortFieldFn.apply(input);
Iterable<? extends Expression.SortField> condition = sortFieldFn.apply(input);
return Sort.builder().input(input).sortFields(condition).remap(remap).build();
}

Expand Down Expand Up @@ -465,7 +479,7 @@ public Switch switchExpression(

public AggregateFunctionInvocation aggregateFn(
String namespace, String key, Type outputType, Expression... args) {
var declaration =
SimpleExtension.AggregateFunctionVariant declaration =
extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return AggregateFunctionInvocation.builder()
.arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList()))
Expand All @@ -477,7 +491,7 @@ public AggregateFunctionInvocation aggregateFn(
}

public Aggregate.Grouping grouping(Rel input, int... indexes) {
var columns = fieldReferences(input, indexes);
List<FieldReference> columns = fieldReferences(input, indexes);
return Aggregate.Grouping.builder().addAllExpressions(columns).build();
}

Expand All @@ -486,7 +500,7 @@ public Aggregate.Grouping grouping(Expression... expressions) {
}

public Aggregate.Measure count(Rel input, int field) {
var declaration =
SimpleExtension.AggregateFunctionVariant declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
Expand Down Expand Up @@ -563,7 +577,7 @@ public Aggregate.Measure sum0(Expression expr) {
private Aggregate.Measure singleArgumentArithmeticAggregate(
Expression expr, String functionName, Type outputType) {
String typeString = ToTypeString.apply(expr.getType());
var declaration =
SimpleExtension.AggregateFunctionVariant declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
Expand All @@ -585,7 +599,7 @@ private Aggregate.Measure singleArgumentArithmeticAggregate(

public Expression.ScalarFunctionInvocation negate(Expression expr) {
// output type of negate is the same as the input type
var outputType = expr.getType();
Type outputType = expr.getType();
return scalarFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("negate:%s", ToTypeString.apply(outputType)),
Expand All @@ -611,12 +625,12 @@ public Expression.ScalarFunctionInvocation divide(Expression left, Expression ri

private Expression.ScalarFunctionInvocation arithmeticFunction(
String fname, Expression left, Expression right) {
var leftTypeStr = ToTypeString.apply(left.getType());
var rightTypeStr = ToTypeString.apply(right.getType());
var key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr);
String leftTypeStr = ToTypeString.apply(left.getType());
String rightTypeStr = ToTypeString.apply(right.getType());
String key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr);

var isOutputNullable = left.getType().nullable() || right.getType().nullable();
var outputType = left.getType();
boolean isOutputNullable = left.getType().nullable() || right.getType().nullable();
Type outputType = left.getType();
outputType =
isOutputNullable
? TypeCreator.asNullable(outputType)
Expand All @@ -633,14 +647,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig
public Expression.ScalarFunctionInvocation or(Expression... args) {
// If any arg is nullable, the output of or is potentially nullable
// For example: false or null = null
var isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
var outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
}

public Expression.ScalarFunctionInvocation scalarFn(
String namespace, String key, Type outputType, FunctionArg... args) {
var declaration =
SimpleExtension.ScalarFunctionVariant declaration =
extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return Expression.ScalarFunctionInvocation.builder()
.declaration(declaration)
Expand All @@ -659,7 +673,7 @@ public Expression.WindowFunctionInvocation windowFn(
WindowBound lowerBound,
WindowBound upperBound,
Expression... args) {
var declaration =
SimpleExtension.WindowFunctionVariant declaration =
extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
Expand Down
12 changes: 6 additions & 6 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() {

public static WindowBoundsType fromProto(
io.substrait.proto.Expression.WindowFunction.BoundsType proto) {
for (var v : values()) {
for (WindowBoundsType v : values()) {
if (v.proto == proto) {
return v;
}
Expand Down Expand Up @@ -984,7 +984,7 @@ public io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp toProto()

public static PredicateOp fromProto(
io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) {
for (var v : values()) {
for (PredicateOp v : values()) {
if (v.proto == proto) {
return v;
}
Expand All @@ -1010,7 +1010,7 @@ public io.substrait.proto.AggregateFunction.AggregationInvocation toProto() {
}

public static AggregationInvocation fromProto(AggregateFunction.AggregationInvocation proto) {
for (var v : values()) {
for (AggregationInvocation v : values()) {
if (v.proto == proto) {
return v;
}
Expand Down Expand Up @@ -1041,7 +1041,7 @@ public io.substrait.proto.AggregationPhase toProto() {
}

public static AggregationPhase fromProto(io.substrait.proto.AggregationPhase proto) {
for (var v : values()) {
for (AggregationPhase v : values()) {
if (v.proto == proto) {
return v;
}
Expand Down Expand Up @@ -1069,7 +1069,7 @@ public io.substrait.proto.SortField.SortDirection toProto() {
}

public static SortDirection fromProto(io.substrait.proto.SortField.SortDirection proto) {
for (var v : values()) {
for (SortDirection v : values()) {
if (v.proto == proto) {
return v;
}
Expand Down Expand Up @@ -1097,7 +1097,7 @@ public io.substrait.proto.Expression.Cast.FailureBehavior toProto() {

public static FailureBehavior fromProto(
io.substrait.proto.Expression.Cast.FailureBehavior proto) {
for (var v : values()) {
for (FailureBehavior v : values()) {
if (v.proto == proto) {
return v;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.substrait.type.Type;
import io.substrait.util.DecimalUtil;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
Expand Down Expand Up @@ -89,7 +90,7 @@ public static Expression.TimestampLiteral timestamp(boolean nullable, long value
*/
@Deprecated
public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateTime value) {
var epochMicro =
long epochMicro =
TimeUnit.SECONDS.toMicros(value.toEpochSecond(ZoneOffset.UTC))
+ TimeUnit.NANOSECONDS.toMicros(value.toLocalTime().getNano());
return timestamp(nullable, epochMicro);
Expand Down Expand Up @@ -127,7 +128,7 @@ public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, long v
*/
@Deprecated
public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, Instant value) {
var epochMicro =
long epochMicro =
TimeUnit.SECONDS.toMicros(value.getEpochSecond())
+ TimeUnit.NANOSECONDS.toMicros(value.getNano());
return timestampTZ(nullable, epochMicro);
Expand Down Expand Up @@ -195,7 +196,7 @@ public static Expression.IntervalCompoundLiteral intervalCompound(
}

public static Expression.UUIDLiteral uuid(boolean nullable, ByteString uuid) {
var bb = uuid.asReadOnlyByteBuffer();
ByteBuffer bb = uuid.asReadOnlyByteBuffer();
return Expression.UUIDLiteral.builder()
.nullable(nullable)
.value(new UUID(bb.getLong(), bb.getLong()))
Expand Down Expand Up @@ -237,7 +238,7 @@ public static Expression.DecimalLiteral decimal(

public static Expression.DecimalLiteral decimal(
boolean nullable, BigDecimal value, int precision, int scale) {
var twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16);
byte[] twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16);

return Expression.DecimalLiteral.builder()
.nullable(nullable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private static FieldReference of(
Collections.reverse(segments);
for (int i = 0; i < segments.size(); i++) {
if (i == 0) {
var last = segments.get(0);
ReferenceSegment last = segments.get(0);
reference =
struct == null ? last.constructOnExpression(expression) : last.constructOnRoot(struct);
} else {
Expand Down
Loading
Loading