Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ dependencies {
implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:${JACKSON_VERSION}")
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${JACKSON_VERSION}")
implementation("com.google.code.findbugs:jsr305:3.0.2")
implementation("com.google.protobuf:protobuf-java-util:${PROTOBUF_VERSION}") {
exclude("com.google.guava", "guava")
.because("Brings in Guava for Android, which we don't want (and breaks multimaps).")
}

antlr("org.antlr:antlr4:${ANTLR_VERSION}")
shadowImplementation("org.antlr:antlr4-runtime:${ANTLR_VERSION}")
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/java/io/substrait/plan/Plan.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.plan;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import io.substrait.proto.AdvancedExtension;
import io.substrait.relation.Rel;
import java.util.List;
Expand Down Expand Up @@ -29,4 +31,26 @@ public static ImmutableRoot.Builder builder() {
return ImmutableRoot.builder();
}
}

/**
* Serializes this plan as protobuf.
*
* @return this plan in protobuf format
*/
public io.substrait.proto.Plan toProto() {
return new PlanProtoConverter().toProto(this);
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking on this some more, it's fine to include this.


/**
* Serializes this plan as a protobuf JSON string.
*
* @return this plan as a protobuf JSON string
*/
public String toJsonString() {
try {
return JsonFormat.printer().includingDefaultValueFields().print(this.toProto());
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException("Can not generate JSON from proto.", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.plan.Plan;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
Expand Down Expand Up @@ -40,7 +39,6 @@ public class TypeExtensionTest {
final SubstraitBuilder b = new SubstraitBuilder(extensionCollection);
Type customType1 = b.userDefinedType(NAMESPACE, "customType1");
Type customType2 = b.userDefinedType(NAMESPACE, "customType2");
final PlanProtoConverter planProtoConverter = new PlanProtoConverter();
final ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensionCollection);

@Test
Expand Down Expand Up @@ -73,7 +71,7 @@ void roundtripCustomType() {
.collect(Collectors.toList()),
b.namedScan(tableName, columnNames, types))));

var protoPlan = planProtoConverter.toProto(plan);
var protoPlan = plan.toProto();
var planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
}
Expand All @@ -99,7 +97,7 @@ void roundtripNumberedAnyTypes() {
b.fieldReference(input, 0)))
.collect(Collectors.toList()),
b.namedScan(tableName, columnNames, types))));
var protoPlan = planProtoConverter.toProto(plan);
var protoPlan = plan.toProto();
var planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
}
Expand Down
3 changes: 1 addition & 2 deletions examples/substrait-spark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,7 @@ Let's look at the APIs in the `createSubstrait(...)` method to see how it's usin
The `io.substrait.plan.Plan` object is a high-level Substrait POJO representing a plan. This could be used directly or more likely be persisted. protobuf is the canonical serialization form. It's easy to convert this and store in a file

```java
PlanProtoConverter planToProto = new PlanProtoConverter();
byte[] buffer = planToProto.toProto(plan).toByteArray();
byte[] buffer = plan.toProto().toByteArray();
try {
Files.write(Paths.get(ROOT_DIR, "spark_sql_substrait.plan"),buffer);
} catch (IOException e){
Expand Down
16 changes: 10 additions & 6 deletions examples/substrait-spark/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ repositories {
mavenCentral()
}

var SPARK_VERSION = properties.get("spark.version")

dependencies {
implementation("org.apache.spark:spark-core_2.12:3.5.1")
implementation("io.substrait:spark:0.36.0")
implementation("io.substrait:core:0.36.0")
implementation("org.apache.spark:spark-sql_2.12:3.5.1")
implementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}")
implementation(project(":spark"))
implementation(project(":core"))
implementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}")

// For a real Spark application, these would not be required since they would be in the Spark
// server classpath
runtimeOnly("org.apache.spark:spark-core_2.12:3.5.1")
runtimeOnly("org.apache.spark:spark-hive_2.12:3.5.1")
runtimeOnly("org.apache.spark:spark-core_2.12:${SPARK_VERSION}")
runtimeOnly("org.apache.spark:spark-hive_2.12:${SPARK_VERSION}")
}

tasks.jar {
Expand All @@ -30,6 +32,8 @@ tasks.jar {
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
manifest.attributes["Main-Class"] = "io.substrait.examples.App"
from(configurations.runtimeClasspath.get().map({ if (it.isDirectory) it else zipTree(it) }))

dependsOn(":core:shadowJar", ":core:jar")
}

tasks.named<Test>("test") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import static io.substrait.examples.SparkHelper.VEHICLES_CSV;

import io.substrait.examples.util.SubstraitStringify;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.spark.logical.ToSubstraitRel;
import java.io.IOException;
import java.nio.file.Files;
Expand Down Expand Up @@ -73,8 +72,7 @@ public void createSubstrait(LogicalPlan enginePlan) {

SubstraitStringify.explain(plan).forEach(System.out::println);

PlanProtoConverter planToProto = new PlanProtoConverter();
byte[] buffer = planToProto.toProto(plan).toByteArray();
byte[] buffer = plan.toProto().toByteArray();
try {
Files.write(Paths.get(ROOT_DIR, "spark_dataset_substrait.plan"), buffer);
System.out.println("File written to " + Paths.get(ROOT_DIR, "spark_sql_substrait.plan"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import static io.substrait.examples.SparkHelper.VEHICLE_TABLE;

import io.substrait.examples.util.SubstraitStringify;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.spark.logical.ToSubstraitRel;
import java.io.IOException;
import java.nio.file.Files;
Expand Down Expand Up @@ -80,8 +79,7 @@ public void createSubstrait(LogicalPlan enginePlan) {

SubstraitStringify.explain(plan).forEach(System.out::println);

PlanProtoConverter planToProto = new PlanProtoConverter();
byte[] buffer = planToProto.toProto(plan).toByteArray();
byte[] buffer = plan.toProto().toByteArray();
try {
Files.write(Paths.get(ROOT_DIR, "spark_sql_substrait.plan"), buffer);
System.out.println("File written to " + Paths.get(ROOT_DIR, "spark_sql_substrait.plan"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.Expression.DateLiteral;
import io.substrait.expression.Expression.DecimalLiteral;
import io.substrait.expression.Expression.EmptyListLiteral;
import io.substrait.expression.Expression.EmptyMapLiteral;
import io.substrait.expression.Expression.FP32Literal;
import io.substrait.expression.Expression.FP64Literal;
import io.substrait.expression.Expression.FixedBinaryLiteral;
Expand All @@ -16,12 +17,15 @@
import io.substrait.expression.Expression.I8Literal;
import io.substrait.expression.Expression.IfThen;
import io.substrait.expression.Expression.InPredicate;
import io.substrait.expression.Expression.IntervalCompoundLiteral;
import io.substrait.expression.Expression.IntervalDayLiteral;
import io.substrait.expression.Expression.IntervalYearLiteral;
import io.substrait.expression.Expression.ListLiteral;
import io.substrait.expression.Expression.MapLiteral;
import io.substrait.expression.Expression.MultiOrList;
import io.substrait.expression.Expression.NullLiteral;
import io.substrait.expression.Expression.PrecisionTimestampLiteral;
import io.substrait.expression.Expression.PrecisionTimestampTZLiteral;
import io.substrait.expression.Expression.ScalarFunctionInvocation;
import io.substrait.expression.Expression.ScalarSubquery;
import io.substrait.expression.Expression.SetPredicate;
Expand Down Expand Up @@ -124,7 +128,7 @@ public String visit(IntervalYearLiteral expr) throws RuntimeException {

@Override
public String visit(IntervalDayLiteral expr) throws RuntimeException {
return "<IntervalYearLiteral " + expr.seconds() + " " + expr.days() + ">";
return "<IntervalDayLiteral " + expr.seconds() + " " + expr.days() + ">";
}

@Override
Expand Down Expand Up @@ -273,4 +277,32 @@ public String visit(InPredicate expr) throws RuntimeException {

return sb.toString();
}

@Override
public String visit(PrecisionTimestampLiteral expr) throws RuntimeException {
return "<PrecisionTimestampLiteral " + expr.value() + ">";
}

@Override
public String visit(PrecisionTimestampTZLiteral expr) throws RuntimeException {
return "<PrecisionTimestampTZLiteral " + expr.value() + ">";
}

@Override
public String visit(IntervalCompoundLiteral expr) throws RuntimeException {
return "<IntervalCompoundLiteral "
+ expr.months()
+ " "
+ expr.years()
+ " "
+ expr.seconds()
+ " "
+ expr.days()
+ ">";
}

@Override
public String visit(EmptyMapLiteral expr) throws RuntimeException {
return "<EmptyMapLiteral >";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Expand;
import io.substrait.relation.ExtensionLeaf;
import io.substrait.relation.ExtensionMulti;
import io.substrait.relation.ExtensionSingle;
Expand Down Expand Up @@ -336,4 +337,10 @@ public String visit(ConsistentPartitionWindow consistentPartitionWindow) throws
StringBuilder sb = getIndent().append("consistentPartitionWindow:: ");
return getOutdent(sb);
}

@Override
public String visit(Expand expand) throws RuntimeException {
StringBuilder sb = getIndent().append("expand:: ");
return getOutdent(sb);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import io.substrait.type.Type.I32;
import io.substrait.type.Type.I64;
import io.substrait.type.Type.I8;
import io.substrait.type.Type.IntervalCompound;
import io.substrait.type.Type.IntervalDay;
import io.substrait.type.Type.IntervalYear;
import io.substrait.type.Type.ListType;
import io.substrait.type.Type.Map;
import io.substrait.type.Type.PrecisionTime;
import io.substrait.type.Type.Str;
import io.substrait.type.Type.Struct;
import io.substrait.type.Type.Time;
Expand Down Expand Up @@ -172,4 +174,14 @@ public String visit(Map type) throws RuntimeException {
public String visit(UserDefined type) throws RuntimeException {
return type.getClass().getSimpleName();
}

@Override
public String visit(PrecisionTime type) throws RuntimeException {
return type.getClass().getSimpleName();
}

@Override
public String visit(IntervalCompound type) throws RuntimeException {
return type.getClass().getSimpleName();
}
}
8 changes: 3 additions & 5 deletions isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.Plan;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
Expand Down Expand Up @@ -65,7 +64,7 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List<Stri
throws SqlParseException {
io.substrait.proto.Plan protoPlan1 = s.execute(query, creates);
Plan plan = new ProtoPlanConverter(EXTENSION_COLLECTION).from(protoPlan1);
io.substrait.proto.Plan protoPlan2 = new PlanProtoConverter().toProto(plan);
io.substrait.proto.Plan protoPlan2 = plan.toProto();
assertEquals(protoPlan1, protoPlan2);
var rootRels = s.sqlToRelNode(query, creates);
assertEquals(rootRels.size(), plan.getRoots().size());
Expand All @@ -78,9 +77,8 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List<Stri
}

protected void assertPlanRoundtrip(Plan plan) {
io.substrait.proto.Plan protoPlan1 = new PlanProtoConverter().toProto(plan);
io.substrait.proto.Plan protoPlan2 =
new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan1));
io.substrait.proto.Plan protoPlan1 = plan.toProto();
io.substrait.proto.Plan protoPlan2 = new ProtoPlanConverter().from(protoPlan1).toProto();
assertEquals(protoPlan1, protoPlan2);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import io.substrait.isthmus.utils.SetUtils;
import io.substrait.plan.Plan;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.proto.AggregateFunction;
import io.substrait.relation.Cross;
Expand Down Expand Up @@ -60,8 +59,7 @@ public void distinctCount() throws IOException, SqlParseException {
String distinctQuery = "select count(DISTINCT L_ORDERKEY) from lineitem";
io.substrait.proto.Plan protoPlan = getProtoPlan(distinctQuery);
assertAggregateInvocationDistinct(protoPlan);
assertAggregateInvocationDistinct(
new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan)));
assertAggregateInvocationDistinct(new ProtoPlanConverter().from(protoPlan).toProto());
}

@Test
Expand Down
4 changes: 2 additions & 2 deletions spark/src/test/scala/io/substrait/spark/LocalFiles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class LocalFiles extends SharedSparkSession {
val substraitPlan = toSubstrait.convert(sparkPlan)

// Serialize to proto buffer
val bytes = new PlanProtoConverter()
.toProto(substraitPlan)
val bytes = substraitPlan
.toProto()
.toByteArray

// Read it back
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
def assertProtoPlanRoundrip(sql: String): Plan = {
val protoPlan1 = sqlToProtoPlan(sql)
val plan = new ProtoPlanConverter().from(protoPlan1)
val protoPlan2 = new PlanProtoConverter().toProto(plan)
val protoPlan2 = plan.toProto()
assertResult(protoPlan1)(protoPlan2)
assertResult(1)(plan.getRoots.size())
plan
Expand Down Expand Up @@ -136,8 +136,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
}

def assertPlanRoundrip(plan: Plan): Unit = {
val protoPlan1 = new PlanProtoConverter().toProto(plan)
val protoPlan2 = new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan1))
val protoPlan1 = plan.toProto()
val protoPlan2 = new ProtoPlanConverter().from(protoPlan1).toProto()
assertResult(protoPlan1)(protoPlan2)
}

Expand Down