diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 0c7ec9199..aff053d64 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -90,6 +90,10 @@ dependencies { implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:${JACKSON_VERSION}") implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${JACKSON_VERSION}") implementation("com.google.code.findbugs:jsr305:3.0.2") + implementation("com.google.protobuf:protobuf-java-util:${PROTOBUF_VERSION}") { + exclude("com.google.guava", "guava") + .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") + } antlr("org.antlr:antlr4:${ANTLR_VERSION}") shadowImplementation("org.antlr:antlr4-runtime:${ANTLR_VERSION}") diff --git a/core/src/main/java/io/substrait/plan/Plan.java b/core/src/main/java/io/substrait/plan/Plan.java index 9d9bd3545..3adc900d8 100644 --- a/core/src/main/java/io/substrait/plan/Plan.java +++ b/core/src/main/java/io/substrait/plan/Plan.java @@ -1,5 +1,7 @@ package io.substrait.plan; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; import io.substrait.proto.AdvancedExtension; import io.substrait.relation.Rel; import java.util.List; @@ -29,4 +31,26 @@ public static ImmutableRoot.Builder builder() { return ImmutableRoot.builder(); } } + + /** + * Serializes this plan as protobuf. + * + * @return this plan in protobuf format + */ + public io.substrait.proto.Plan toProto() { + return new PlanProtoConverter().toProto(this); + } + + /** + * Serializes this plan as a protobuf JSON string. + * + * @return this plan as a protobuf JSON string + */ + public String toJsonString() { + try { + return JsonFormat.printer().includingDefaultValueFields().print(this.toProto()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException("Can not generate JSON from proto.", e); + } + } } diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index f19bbef5b..1d32c83c5 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -5,7 +5,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.plan.Plan; -import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; @@ -40,7 +39,6 @@ public class TypeExtensionTest { final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); Type customType1 = b.userDefinedType(NAMESPACE, "customType1"); Type customType2 = b.userDefinedType(NAMESPACE, "customType2"); - final PlanProtoConverter planProtoConverter = new PlanProtoConverter(); final ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensionCollection); @Test @@ -73,7 +71,7 @@ void roundtripCustomType() { .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - var protoPlan = planProtoConverter.toProto(plan); + var protoPlan = plan.toProto(); var planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } @@ -99,7 +97,7 @@ void roundtripNumberedAnyTypes() { b.fieldReference(input, 0))) .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - var protoPlan = planProtoConverter.toProto(plan); + var protoPlan = plan.toProto(); var planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } diff --git a/examples/substrait-spark/README.md b/examples/substrait-spark/README.md index d9885dd83..f60c9d01a 100644 --- a/examples/substrait-spark/README.md +++ b/examples/substrait-spark/README.md @@ -263,8 +263,7 @@ Let's look at the APIs in the `createSubstrait(...)` method to see how it's usin The `io.substrait.plan.Plan` object is a high-level Substrait POJO representing a plan. This could be used directly or more likely be persisted. protobuf is the canonical serialization form. It's easy to convert this and store in a file ```java - PlanProtoConverter planToProto = new PlanProtoConverter(); - byte[] buffer = planToProto.toProto(plan).toByteArray(); + byte[] buffer = plan.toProto().toByteArray(); try { Files.write(Paths.get(ROOT_DIR, "spark_sql_substrait.plan"),buffer); } catch (IOException e){ diff --git a/examples/substrait-spark/build.gradle.kts b/examples/substrait-spark/build.gradle.kts index 212f2b11c..326fdf584 100644 --- a/examples/substrait-spark/build.gradle.kts +++ b/examples/substrait-spark/build.gradle.kts @@ -9,16 +9,18 @@ repositories { mavenCentral() } +var SPARK_VERSION = properties.get("spark.version") + dependencies { - implementation("org.apache.spark:spark-core_2.12:3.5.1") - implementation("io.substrait:spark:0.36.0") - implementation("io.substrait:core:0.36.0") - implementation("org.apache.spark:spark-sql_2.12:3.5.1") + implementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}") + implementation(project(":spark")) + implementation(project(":core")) + implementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}") // For a real Spark application, these would not be required since they would be in the Spark // server classpath - runtimeOnly("org.apache.spark:spark-core_2.12:3.5.1") - runtimeOnly("org.apache.spark:spark-hive_2.12:3.5.1") + runtimeOnly("org.apache.spark:spark-core_2.12:${SPARK_VERSION}") + runtimeOnly("org.apache.spark:spark-hive_2.12:${SPARK_VERSION}") } tasks.jar { @@ -30,6 +32,8 @@ tasks.jar { duplicatesStrategy = DuplicatesStrategy.EXCLUDE manifest.attributes["Main-Class"] = "io.substrait.examples.App" from(configurations.runtimeClasspath.get().map({ if (it.isDirectory) it else zipTree(it) })) + + dependsOn(":core:shadowJar", ":core:jar") } tasks.named("test") { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java index 81de54b0b..32bbcc9de 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java @@ -5,7 +5,6 @@ import static io.substrait.examples.SparkHelper.VEHICLES_CSV; import io.substrait.examples.util.SubstraitStringify; -import io.substrait.plan.PlanProtoConverter; import io.substrait.spark.logical.ToSubstraitRel; import java.io.IOException; import java.nio.file.Files; @@ -73,8 +72,7 @@ public void createSubstrait(LogicalPlan enginePlan) { SubstraitStringify.explain(plan).forEach(System.out::println); - PlanProtoConverter planToProto = new PlanProtoConverter(); - byte[] buffer = planToProto.toProto(plan).toByteArray(); + byte[] buffer = plan.toProto().toByteArray(); try { Files.write(Paths.get(ROOT_DIR, "spark_dataset_substrait.plan"), buffer); System.out.println("File written to " + Paths.get(ROOT_DIR, "spark_sql_substrait.plan")); diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java index fc4c8fec4..81f79f017 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java @@ -7,7 +7,6 @@ import static io.substrait.examples.SparkHelper.VEHICLE_TABLE; import io.substrait.examples.util.SubstraitStringify; -import io.substrait.plan.PlanProtoConverter; import io.substrait.spark.logical.ToSubstraitRel; import java.io.IOException; import java.nio.file.Files; @@ -80,8 +79,7 @@ public void createSubstrait(LogicalPlan enginePlan) { SubstraitStringify.explain(plan).forEach(System.out::println); - PlanProtoConverter planToProto = new PlanProtoConverter(); - byte[] buffer = planToProto.toProto(plan).toByteArray(); + byte[] buffer = plan.toProto().toByteArray(); try { Files.write(Paths.get(ROOT_DIR, "spark_sql_substrait.plan"), buffer); System.out.println("File written to " + Paths.get(ROOT_DIR, "spark_sql_substrait.plan")); diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index e8630200e..44c8bab71 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -6,6 +6,7 @@ import io.substrait.expression.Expression.DateLiteral; import io.substrait.expression.Expression.DecimalLiteral; import io.substrait.expression.Expression.EmptyListLiteral; +import io.substrait.expression.Expression.EmptyMapLiteral; import io.substrait.expression.Expression.FP32Literal; import io.substrait.expression.Expression.FP64Literal; import io.substrait.expression.Expression.FixedBinaryLiteral; @@ -16,12 +17,15 @@ import io.substrait.expression.Expression.I8Literal; import io.substrait.expression.Expression.IfThen; import io.substrait.expression.Expression.InPredicate; +import io.substrait.expression.Expression.IntervalCompoundLiteral; import io.substrait.expression.Expression.IntervalDayLiteral; import io.substrait.expression.Expression.IntervalYearLiteral; import io.substrait.expression.Expression.ListLiteral; import io.substrait.expression.Expression.MapLiteral; import io.substrait.expression.Expression.MultiOrList; import io.substrait.expression.Expression.NullLiteral; +import io.substrait.expression.Expression.PrecisionTimestampLiteral; +import io.substrait.expression.Expression.PrecisionTimestampTZLiteral; import io.substrait.expression.Expression.ScalarFunctionInvocation; import io.substrait.expression.Expression.ScalarSubquery; import io.substrait.expression.Expression.SetPredicate; @@ -124,7 +128,7 @@ public String visit(IntervalYearLiteral expr) throws RuntimeException { @Override public String visit(IntervalDayLiteral expr) throws RuntimeException { - return ""; + return ""; } @Override @@ -273,4 +277,32 @@ public String visit(InPredicate expr) throws RuntimeException { return sb.toString(); } + + @Override + public String visit(PrecisionTimestampLiteral expr) throws RuntimeException { + return ""; + } + + @Override + public String visit(PrecisionTimestampTZLiteral expr) throws RuntimeException { + return ""; + } + + @Override + public String visit(IntervalCompoundLiteral expr) throws RuntimeException { + return ""; + } + + @Override + public String visit(EmptyMapLiteral expr) throws RuntimeException { + return ""; + } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index f650e8f62..973d1aaef 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -4,6 +4,7 @@ import io.substrait.relation.ConsistentPartitionWindow; import io.substrait.relation.Cross; import io.substrait.relation.EmptyScan; +import io.substrait.relation.Expand; import io.substrait.relation.ExtensionLeaf; import io.substrait.relation.ExtensionMulti; import io.substrait.relation.ExtensionSingle; @@ -336,4 +337,10 @@ public String visit(ConsistentPartitionWindow consistentPartitionWindow) throws StringBuilder sb = getIndent().append("consistentPartitionWindow:: "); return getOutdent(sb); } + + @Override + public String visit(Expand expand) throws RuntimeException { + StringBuilder sb = getIndent().append("expand:: "); + return getOutdent(sb); + } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 796e9ca6b..f935dcd52 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -13,10 +13,12 @@ import io.substrait.type.Type.I32; import io.substrait.type.Type.I64; import io.substrait.type.Type.I8; +import io.substrait.type.Type.IntervalCompound; import io.substrait.type.Type.IntervalDay; import io.substrait.type.Type.IntervalYear; import io.substrait.type.Type.ListType; import io.substrait.type.Type.Map; +import io.substrait.type.Type.PrecisionTime; import io.substrait.type.Type.Str; import io.substrait.type.Type.Struct; import io.substrait.type.Type.Time; @@ -172,4 +174,14 @@ public String visit(Map type) throws RuntimeException { public String visit(UserDefined type) throws RuntimeException { return type.getClass().getSimpleName(); } + + @Override + public String visit(PrecisionTime type) throws RuntimeException { + return type.getClass().getSimpleName(); + } + + @Override + public String visit(IntervalCompound type) throws RuntimeException { + return type.getClass().getSimpleName(); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 7c7770172..01576850a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -12,7 +12,6 @@ import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan; -import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; @@ -65,7 +64,7 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List def assertProtoPlanRoundrip(sql: String): Plan = { val protoPlan1 = sqlToProtoPlan(sql) val plan = new ProtoPlanConverter().from(protoPlan1) - val protoPlan2 = new PlanProtoConverter().toProto(plan) + val protoPlan2 = plan.toProto() assertResult(protoPlan1)(protoPlan2) assertResult(1)(plan.getRoots.size()) plan @@ -136,8 +136,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => } def assertPlanRoundrip(plan: Plan): Unit = { - val protoPlan1 = new PlanProtoConverter().toProto(plan) - val protoPlan2 = new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan1)) + val protoPlan1 = plan.toProto() + val protoPlan2 = new ProtoPlanConverter().from(protoPlan1).toProto() assertResult(protoPlan1)(protoPlan2) }