Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.hint.Hint;
import io.substrait.plan.Plan;
import io.substrait.proto.AggregateRel;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.CrossRel;
Expand Down Expand Up @@ -61,6 +62,10 @@ public ProtoRelConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollec
this.protoTypeConverter = new ProtoTypeConverter(lookup, extensions);
}

public Plan.Root from(io.substrait.proto.RelRoot rel) {
return Plan.Root.builder().input(from(rel.getInput())).addAllNames(rel.getNamesList()).build();
}

public Rel from(io.substrait.proto.Rel rel) {
var relType = rel.getRelTypeCase();
switch (relType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.plan.Plan;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.AggregateRel;
import io.substrait.proto.ConsistentPartitionWindowRel;
Expand All @@ -24,6 +25,7 @@
import io.substrait.proto.ReadRel;
import io.substrait.proto.Rel;
import io.substrait.proto.RelCommon;
import io.substrait.proto.RelRoot;
import io.substrait.proto.SetRel;
import io.substrait.proto.SortField;
import io.substrait.proto.SortRel;
Expand Down Expand Up @@ -59,6 +61,13 @@ public TypeProtoConverter getTypeProtoConverter() {
return this.typeProtoConverter;
}

public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) {
return RelRoot.newBuilder()
.setInput(toProto(relRoot.getInput()))
.addAllNames(relRoot.getNames())
.build();
}

public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) {
return rel.accept(this);
}
Expand Down
12 changes: 3 additions & 9 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,9 @@ private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogRea
plan.addRelations(
PlanRel.newBuilder()
.setRoot(
io.substrait.proto.RelRoot.newBuilder()
.setInput(
SubstraitRelVisitor.convert(
root, EXTENSION_COLLECTION, featureBoard)
.accept(relProtoConverter))
.addAllNames(
TypeConverter.DEFAULT
.toNamedStruct(root.validatedRowType)
.names())));
relProtoConverter.toProto(
SubstraitRelVisitor.convert(
root, EXTENSION_COLLECTION, featureBoard))));
});
functionCollector.addExtensionsToPlan(plan);
return plan.build();
Expand Down
33 changes: 23 additions & 10 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
Expand Down Expand Up @@ -379,20 +380,32 @@ public List<Rel> apply(List<RelNode> inputs) {
.collect(java.util.stream.Collectors.toList());
}

public static Rel convert(RelRoot root, SimpleExtension.ExtensionCollection extensions) {
return convert(root.rel, extensions, FEATURES_DEFAULT);
public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) {
return convert(relRoot, extensions, FEATURES_DEFAULT);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking API change.

A CalciteRelRoot corresponds to POJO Plan.Root, not a POJO Rel.

}

public static Rel convert(
RelRoot root, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
return convert(root.rel, extensions, features);
public static Plan.Root convert(
RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
SubstraitRelVisitor visitor =
new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(relRoot.rel);
Rel rel = visitor.apply(relRoot.project());

// Avoid using the names from relRoot.validatedRowType because if there are
// nested types (i.e ROW, MAP, etc) the typeConverter will pad names correctly
List<String> names = visitor.typeConverter.toNamedStruct(relRoot.validatedRowType).names();
return Plan.Root.builder().input(rel).names(names).build();
}

private static Rel convert(
RelNode rel, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) {
return convert(relNode, extensions, FEATURES_DEFAULT);
}

public static Rel convert(
RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
SubstraitRelVisitor visitor =
new SubstraitRelVisitor(rel.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(rel);
return visitor.apply(rel);
new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features);
visitor.popFieldAccessDepthMap(relNode);
return visitor.apply(relNode);
}
}
51 changes: 51 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.plan.Plan;
import io.substrait.relation.NamedScan;
import java.util.List;
import org.junit.jupiter.api.Test;

public class NameRoundtripTest extends PlanTestBase {

@Test
void preserveNamesFromSql() throws Exception {
List<String> creates = List.of("CREATE TABLE foo(a BIGINT, b BIGINT)");

SqlToSubstrait s = new SqlToSubstrait();
var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory);

String query = """
SELECT "a", "B" FROM foo GROUP BY a, b
""";
List<String> expectedNames = List.of("a", "B");

List<org.apache.calcite.rel.RelRoot> calciteRelRoots = s.sqlToRelNode(query, creates);
assertEquals(1, calciteRelRoots.size());

org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0);
assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames());

io.substrait.plan.Plan.Root substraitRelRoot =
SubstraitRelVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION);
assertEquals(expectedNames, substraitRelRoot.getNames());

org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot);
assertEquals(expectedNames, calciteRelRoot2.validatedRowType.getFieldNames());
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was hoping to be able to use to the new Plan.Root based assertFullRoundTrip to test this, but it doesn't quite work: #371

}

@Test
void preserveNamesFromSubstrait() {
NamedScan rel =
substraitBuilder.namedScan(
List.of("foo"),
List.of("i64", "struct", "struct0", "struct1"),
List.of(R.I64, R.struct(R.FP64, R.STRING)));

Plan.Root planRoot =
Plan.Root.builder().input(rel).names(List.of("i", "s", "s0", "s1")).build();
assertFullRoundTrip(planRoot);
}
}
116 changes: 71 additions & 45 deletions isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.tools.RelBuilder;
import org.junit.jupiter.api.Assertions;
Expand Down Expand Up @@ -72,8 +70,9 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List<Stri
var rootRels = s.sqlToRelNode(query, creates);
assertEquals(rootRels.size(), plan.getRoots().size());
for (int i = 0; i < rootRels.size(); i++) {
var rootRel = SubstraitRelVisitor.convert(rootRels.get(i), EXTENSION_COLLECTION);
assertEquals(rootRel.getRecordType(), plan.getRoots().get(i).getInput().getRecordType());
Plan.Root rootRel = SubstraitRelVisitor.convert(rootRels.get(i), EXTENSION_COLLECTION);
assertEquals(
rootRel.getInput().getRecordType(), plan.getRoots().get(i).getInput().getRecordType());
}
return plan;
}
Expand All @@ -85,38 +84,36 @@ protected void assertPlanRoundtrip(Plan plan) {
assertEquals(protoPlan1, protoPlan2);
}

protected List<RelNode> assertSqlSubstraitRelRoundTrip(String query) throws Exception {
protected RelRoot assertSqlSubstraitRelRoundTrip(String query) throws Exception {
return assertSqlSubstraitRelRoundTrip(query, tpchSchemaCreateStatements());
}

protected List<RelNode> assertSqlSubstraitRelRoundTrip(String query, List<String> creates)
protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List<String> creates)
throws Exception {
// sql <--> substrait round trip test.
// Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same.
// Return list of sql -> Substrait rel -> Calcite rel.
List<RelNode> relNodeList = new ArrayList<>();

var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory);

SqlToSubstrait s = new SqlToSubstrait();

// 1. SQL -> Calcite RelRoot
for (RelRoot relRoot : s.sqlToRelNode(query, creates)) {
// 2. Calcite RelRoot -> Substrait Rel
Rel pojo1 = SubstraitRelVisitor.convert(relRoot, EXTENSION_COLLECTION);
List<RelRoot> relRoots = s.sqlToRelNode(query, creates);
assertEquals(1, relRoots.size());
RelRoot relRoot1 = relRoots.get(0);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice we only ever generate a single RelRoot. Capturing this here and removing the looping to simplify our test code.


// 3. Substrait Rel -> Calcite RelNode
RelNode relNode = substraitToCalcite.convert(pojo1);
// 2. Calcite RelRoot -> Substrait Rel
Plan.Root pojo1 = SubstraitRelVisitor.convert(relRoot1, EXTENSION_COLLECTION);

relNodeList.add(relNode);
// 3. Substrait Rel -> Calcite RelNode
RelRoot relRoot2 = substraitToCalcite.convert(pojo1);

// 4. Calcite RelNode -> Substrait Rel
Rel pojo2 =
SubstraitRelVisitor.convert(RelRoot.of(relNode, SqlKind.SELECT), EXTENSION_COLLECTION);
// 4. Calcite RelNode -> Substrait Rel
Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, EXTENSION_COLLECTION);

Assertions.assertEquals(pojo1, pojo2);
}
return relNodeList;
Assertions.assertEquals(pojo1, pojo2);
return relRoot2;
}

@Beta
Expand All @@ -140,37 +137,36 @@ protected void assertFullRoundTrip(String query) throws IOException, SqlParseExc
protected void assertFullRoundTrip(String sqlQuery, List<String> createStatements)
throws SqlParseException {
SqlToSubstrait sqlConverter = new SqlToSubstrait();
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements);
ExtensionCollector extensionCollector = new ExtensionCollector();

for (RelRoot calcite1 : relRoots) {
var extensionCollector = new ExtensionCollector();
// SQL -> Calcite 1
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements);
assertEquals(1, relRoots.size());
RelRoot calcite1 = relRoots.get(0);

// Calcite 1 -> Substrait POJO 1
io.substrait.relation.Rel pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION);
// Calcite 1 -> Substrait POJO 1
Plan.Root pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION);

// Substrait POJO 1 -> Substrait Proto
io.substrait.proto.Rel proto = new RelProtoConverter(extensionCollector).toProto(pojo1);
// Substrait POJO 1 -> Substrait Proto
io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1);

// Substrait Proto -> Substrait Pojo 2
io.substrait.relation.Rel pojo2 =
new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto);
// Substrait Proto -> Substrait Pojo 2
Plan.Root pojo2 = new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto);

// Verify that POJOs are the same
assertEquals(pojo1, pojo2);
// Verify that POJOs are the same
assertEquals(pojo1, pojo2);

// Substrait POJO 2 -> Calcite 2
RelNode calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);
// It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
// do so
assertNotNull(calcite2);
// Substrait POJO 2 -> Calcite 2
RelRoot calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);
// It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
// do so
assertNotNull(calcite2);

// Calcite 2 -> Substrait POJO 3
io.substrait.relation.Rel pojo3 =
SubstraitRelVisitor.convert(RelRoot.of(calcite2, calcite1.kind), EXTENSION_COLLECTION);
// Calcite 2 -> Substrait POJO 3
Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite2, EXTENSION_COLLECTION);

// Verify that POJOs are the same
assertEquals(pojo1, pojo3);
}
// Verify that POJOs are the same
assertEquals(pojo1, pojo3);
}

/**
Expand All @@ -182,6 +178,7 @@ protected void assertFullRoundTrip(String sqlQuery, List<String> createStatement
* </ul>
*/
protected void assertFullRoundTrip(Rel pojo1) {
// TODO: reuse the Plan.Root based assertFullRoundTrip by generating names
var extensionCollector = new ExtensionCollector();

// Substrait POJO 1 -> Substrait Proto
Expand All @@ -198,9 +195,38 @@ protected void assertFullRoundTrip(Rel pojo1) {
RelNode calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);

// Calcite -> Substrait POJO 3
io.substrait.relation.Rel pojo3 =
// SqlKind.SELECT is used because the majority of our tests are SELECT queries
SubstraitRelVisitor.convert(RelRoot.of(calcite, SqlKind.SELECT), EXTENSION_COLLECTION);
io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION);

// Verify that POJOs are the same
assertEquals(pojo1, pojo3);
}

/**
* Verifies that the given POJO can be converted:
*
* <ul>
* <li>From POJO to Proto and back
* <li>From POJO to Calcite and back
* </ul>
*/
protected void assertFullRoundTrip(Plan.Root pojo1) {
var extensionCollector = new ExtensionCollector();

// Substrait POJO 1 -> Substrait Proto
io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1);

// Substrait Proto -> Substrait Pojo 2
io.substrait.plan.Plan.Root pojo2 =
new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto);

// Verify that POJOs are the same
assertEquals(pojo1, pojo2);

// Substrait POJO 2 -> Calcite
RelRoot calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2);

// Calcite -> Substrait POJO 3
io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION);

// Verify that POJOs are the same
assertEquals(pojo1, pojo3);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

import io.substrait.isthmus.utils.SetUtils;
import io.substrait.relation.Set;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -142,15 +143,15 @@ public void tpch_q1_variant() throws Exception {
@Test
public void simpleTestApproxCountDistinct() throws Exception {
String query = "select approx_count_distinct(l_tax) from lineitem";
List<RelNode> relNodeList = assertSqlSubstraitRelRoundTrip(query);
RelRoot relRoot = assertSqlSubstraitRelRoundTrip(query);
RelNode relNode = relRoot.project();

// Assert converted Calcite RelNode has `approx_count_distinct`
RelNode relNode = relNodeList.get(0);
assertTrue(relNode instanceof LogicalAggregate);
assertInstanceOf(LogicalAggregate.class, relNode);
LogicalAggregate aggregate = (LogicalAggregate) relNode;
assertTrue(
aggregate.getAggCallList().get(0).getAggregation()
== SqlStdOperatorTable.APPROX_COUNT_DISTINCT);
assertEquals(
SqlStdOperatorTable.APPROX_COUNT_DISTINCT,
aggregate.getAggCallList().get(0).getAggregation());
}

@Test
Expand Down
Loading