From baa5ec057a67f3902daca876edb8e0cb51196fd2 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 3 Apr 2025 13:13:11 -0700 Subject: [PATCH 1/4] test(isthmus): verify that names are retained between conversions --- .../substrait/isthmus/NameRoundtripTest.java | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java new file mode 100644 index 000000000..490bb1536 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -0,0 +1,34 @@ +package io.substrait.isthmus; + +import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import org.junit.jupiter.api.Test; + +public class NameRoundtripTest extends PlanTestBase { + + @Test + void outputNamesShouldBeConsistent() throws Exception { + List creates = List.of("CREATE TABLE foo(a BIGINT, b BIGINT)"); + + SqlToSubstrait s = new SqlToSubstrait(); + var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); + + String query = """ + SELECT "a", "B" FROM foo GROUP BY a, b + """; + List expectedNames = List.of("a", "B"); + + List calciteRelRoots = s.sqlToRelNode(query, creates); + assertEquals(1, calciteRelRoots.size()); + + org.apache.calcite.rel.RelRoot calciteRelRoot = calciteRelRoots.get(0); + assertEquals(expectedNames, calciteRelRoot.validatedRowType.getFieldNames()); + + io.substrait.relation.Rel substraitRel = + SubstraitRelVisitor.convert(calciteRelRoot, EXTENSION_COLLECTION); + org.apache.calcite.rel.RelNode relNode = substraitToCalcite.convert(substraitRel); + assertEquals(expectedNames, relNode.getRowType().getFieldNames()); + } +} From 54776701328a7c4bd90dad6e23b385bdfb1fb8df Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 3 Apr 2025 14:50:43 -0700 Subject: [PATCH 2/4] feat: convert Calcite RelRoot to Substrait Plan.Root in SubstraitRelVisitor RelRoots must be converted to Plan.Roots in order to ensure that names are handled correctly. BREAKING CHANGE: converting a Calcite RelRoot no longer produces a Substrait Rel --- .../substrait/relation/ProtoRelConverter.java | 5 ++ .../substrait/relation/RelProtoConverter.java | 9 ++ .../io/substrait/isthmus/SqlToSubstrait.java | 12 +-- .../isthmus/SubstraitRelVisitor.java | 37 +++++--- .../substrait/isthmus/NameRoundtripTest.java | 14 ++-- .../io/substrait/isthmus/PlanTestBase.java | 84 +++++++++---------- .../substrait/isthmus/Substrait2SqlTest.java | 17 ++-- 7 files changed, 100 insertions(+), 78 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 0d35277c4..e939b1adf 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -8,6 +8,7 @@ import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; import io.substrait.hint.Hint; +import io.substrait.plan.Plan; import io.substrait.proto.AggregateRel; import io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; @@ -61,6 +62,10 @@ public ProtoRelConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollec this.protoTypeConverter = new ProtoTypeConverter(lookup, extensions); } + public Plan.Root from(io.substrait.proto.RelRoot rel) { + return Plan.Root.builder().input(from(rel.getInput())).addAllNames(rel.getNamesList()).build(); + } + public Rel from(io.substrait.proto.Rel rel) { var relType = rel.getRelTypeCase(); switch (relType) { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index f30778e03..484982031 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.plan.Plan; import io.substrait.proto.AggregateFunction; import io.substrait.proto.AggregateRel; import io.substrait.proto.ConsistentPartitionWindowRel; @@ -24,6 +25,7 @@ import io.substrait.proto.ReadRel; import io.substrait.proto.Rel; import io.substrait.proto.RelCommon; +import io.substrait.proto.RelRoot; import io.substrait.proto.SetRel; import io.substrait.proto.SortField; import io.substrait.proto.SortRel; @@ -59,6 +61,13 @@ public TypeProtoConverter getTypeProtoConverter() { return this.typeProtoConverter; } + public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) { + return RelRoot.newBuilder() + .setInput(toProto(relRoot.getInput())) + .addAllNames(relRoot.getNames()) + .build(); + } + public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) { return rel.accept(this); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 8d55b0682..6fb64c6f7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -66,15 +66,9 @@ private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogRea plan.addRelations( PlanRel.newBuilder() .setRoot( - io.substrait.proto.RelRoot.newBuilder() - .setInput( - SubstraitRelVisitor.convert( - root, EXTENSION_COLLECTION, featureBoard) - .accept(relProtoConverter)) - .addAllNames( - TypeConverter.DEFAULT - .toNamedStruct(root.validatedRowType) - .names()))); + relProtoConverter.toProto( + SubstraitRelVisitor.convert( + root, EXTENSION_COLLECTION, featureBoard)))); }); functionCollector.addExtensionsToPlan(plan); return plan.build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 8e65c59bb..f88f62b0c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -10,6 +10,7 @@ import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.plan.Plan; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; import io.substrait.relation.EmptyScan; @@ -379,20 +380,36 @@ public List apply(List inputs) { .collect(java.util.stream.Collectors.toList()); } - public static Rel convert(RelRoot root, SimpleExtension.ExtensionCollection extensions) { - return convert(root.rel, extensions, FEATURES_DEFAULT); + public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) { + return convert(relRoot, extensions, FEATURES_DEFAULT); } - public static Rel convert( - RelRoot root, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - return convert(root.rel, extensions, features); + public static Plan.Root convert( + RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + SubstraitRelVisitor visitor = + new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features); + visitor.popFieldAccessDepthMap(relRoot.rel); + Rel rel = visitor.apply(relRoot.project()); + + // Avoid using the names from relRoot.validatedRowType because if there are + // nested types (i.e ROW, MAP, etc) the typeConverter will pad names correctly + List names = visitor.typeConverter.toNamedStruct(relRoot.validatedRowType).names(); + return Plan.Root.builder().input(rel).names(names).build(); } - private static Rel convert( - RelNode rel, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) { + SubstraitRelVisitor visitor = + new SubstraitRelVisitor( + relNode.getCluster().getTypeFactory(), extensions, FEATURES_DEFAULT); + visitor.popFieldAccessDepthMap(relNode); + return visitor.apply(relNode); + } + + public static Rel convert( + RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { SubstraitRelVisitor visitor = - new SubstraitRelVisitor(rel.getCluster().getTypeFactory(), extensions, features); - visitor.popFieldAccessDepthMap(rel); - return visitor.apply(rel); + new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features); + visitor.popFieldAccessDepthMap(relNode); + return visitor.apply(relNode); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index 490bb1536..0ba07b39f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -23,12 +23,14 @@ void outputNamesShouldBeConsistent() throws Exception { List calciteRelRoots = s.sqlToRelNode(query, creates); assertEquals(1, calciteRelRoots.size()); - org.apache.calcite.rel.RelRoot calciteRelRoot = calciteRelRoots.get(0); - assertEquals(expectedNames, calciteRelRoot.validatedRowType.getFieldNames()); + org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0); + assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames()); - io.substrait.relation.Rel substraitRel = - SubstraitRelVisitor.convert(calciteRelRoot, EXTENSION_COLLECTION); - org.apache.calcite.rel.RelNode relNode = substraitToCalcite.convert(substraitRel); - assertEquals(expectedNames, relNode.getRowType().getFieldNames()); + io.substrait.plan.Plan.Root substraitRelRoot = + SubstraitRelVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION); + assertEquals(expectedNames, substraitRelRoot.getNames()); + + org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot); + assertEquals(expectedNames, calciteRelRoot2.validatedRowType.getFieldNames()); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 95b2c9f0e..84d1987af 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -20,7 +20,6 @@ import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.calcite.rel.RelNode; @@ -28,7 +27,6 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.tools.RelBuilder; import org.junit.jupiter.api.Assertions; @@ -72,8 +70,9 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List assertSqlSubstraitRelRoundTrip(String query) throws Exception { + protected RelRoot assertSqlSubstraitRelRoundTrip(String query) throws Exception { return assertSqlSubstraitRelRoundTrip(query, tpchSchemaCreateStatements()); } - protected List assertSqlSubstraitRelRoundTrip(String query, List creates) + protected RelRoot assertSqlSubstraitRelRoundTrip(String query, List creates) throws Exception { // sql <--> substrait round trip test. // Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same. // Return list of sql -> Substrait rel -> Calcite rel. - List relNodeList = new ArrayList<>(); var substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); SqlToSubstrait s = new SqlToSubstrait(); // 1. SQL -> Calcite RelRoot - for (RelRoot relRoot : s.sqlToRelNode(query, creates)) { - // 2. Calcite RelRoot -> Substrait Rel - Rel pojo1 = SubstraitRelVisitor.convert(relRoot, EXTENSION_COLLECTION); + List relRoots = s.sqlToRelNode(query, creates); + assertEquals(1, relRoots.size()); + RelRoot relRoot1 = relRoots.get(0); - // 3. Substrait Rel -> Calcite RelNode - RelNode relNode = substraitToCalcite.convert(pojo1); + // 2. Calcite RelRoot -> Substrait Rel + Plan.Root pojo1 = SubstraitRelVisitor.convert(relRoot1, EXTENSION_COLLECTION); - relNodeList.add(relNode); + // 3. Substrait Rel -> Calcite RelNode + RelRoot relRoot2 = substraitToCalcite.convert(pojo1); - // 4. Calcite RelNode -> Substrait Rel - Rel pojo2 = - SubstraitRelVisitor.convert(RelRoot.of(relNode, SqlKind.SELECT), EXTENSION_COLLECTION); + // 4. Calcite RelNode -> Substrait Rel + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, EXTENSION_COLLECTION); - Assertions.assertEquals(pojo1, pojo2); - } - return relNodeList; + Assertions.assertEquals(pojo1, pojo2); + return relRoot2; } @Beta @@ -140,37 +137,36 @@ protected void assertFullRoundTrip(String query) throws IOException, SqlParseExc protected void assertFullRoundTrip(String sqlQuery, List createStatements) throws SqlParseException { SqlToSubstrait sqlConverter = new SqlToSubstrait(); - List relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements); + ExtensionCollector extensionCollector = new ExtensionCollector(); - for (RelRoot calcite1 : relRoots) { - var extensionCollector = new ExtensionCollector(); + // SQL -> Calcite 1 + List relRoots = sqlConverter.sqlToRelNode(sqlQuery, createStatements); + assertEquals(1, relRoots.size()); + RelRoot calcite1 = relRoots.get(0); - // Calcite 1 -> Substrait POJO 1 - io.substrait.relation.Rel pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION); + // Calcite 1 -> Substrait POJO 1 + Plan.Root pojo1 = SubstraitRelVisitor.convert(calcite1, EXTENSION_COLLECTION); - // Substrait POJO 1 -> Substrait Proto - io.substrait.proto.Rel proto = new RelProtoConverter(extensionCollector).toProto(pojo1); + // Substrait POJO 1 -> Substrait Proto + io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1); - // Substrait Proto -> Substrait Pojo 2 - io.substrait.relation.Rel pojo2 = - new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto); + // Substrait Proto -> Substrait Pojo 2 + Plan.Root pojo2 = new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto); - // Verify that POJOs are the same - assertEquals(pojo1, pojo2); + // Verify that POJOs are the same + assertEquals(pojo1, pojo2); - // Substrait POJO 2 -> Calcite 2 - RelNode calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); - // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to - // do so - assertNotNull(calcite2); + // Substrait POJO 2 -> Calcite 2 + RelRoot calcite2 = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); + // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to + // do so + assertNotNull(calcite2); - // Calcite 2 -> Substrait POJO 3 - io.substrait.relation.Rel pojo3 = - SubstraitRelVisitor.convert(RelRoot.of(calcite2, calcite1.kind), EXTENSION_COLLECTION); + // Calcite 2 -> Substrait POJO 3 + Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite2, EXTENSION_COLLECTION); - // Verify that POJOs are the same - assertEquals(pojo1, pojo3); - } + // Verify that POJOs are the same + assertEquals(pojo1, pojo3); } /** @@ -198,9 +194,7 @@ protected void assertFullRoundTrip(Rel pojo1) { RelNode calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.relation.Rel pojo3 = - // SqlKind.SELECT is used because the majority of our tests are SELECT queries - SubstraitRelVisitor.convert(RelRoot.of(calcite, SqlKind.SELECT), EXTENSION_COLLECTION); + io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION); // Verify that POJOs are the same assertEquals(pojo1, pojo3); diff --git a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java index 2a8162caa..096e5c953 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java @@ -1,11 +1,12 @@ package io.substrait.isthmus; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import io.substrait.isthmus.utils.SetUtils; import io.substrait.relation.Set; -import java.util.List; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.junit.jupiter.api.Test; @@ -142,15 +143,15 @@ public void tpch_q1_variant() throws Exception { @Test public void simpleTestApproxCountDistinct() throws Exception { String query = "select approx_count_distinct(l_tax) from lineitem"; - List relNodeList = assertSqlSubstraitRelRoundTrip(query); + RelRoot relRoot = assertSqlSubstraitRelRoundTrip(query); + RelNode relNode = relRoot.project(); // Assert converted Calcite RelNode has `approx_count_distinct` - RelNode relNode = relNodeList.get(0); - assertTrue(relNode instanceof LogicalAggregate); + assertInstanceOf(LogicalAggregate.class, relNode); LogicalAggregate aggregate = (LogicalAggregate) relNode; - assertTrue( - aggregate.getAggCallList().get(0).getAggregation() - == SqlStdOperatorTable.APPROX_COUNT_DISTINCT); + assertEquals( + SqlStdOperatorTable.APPROX_COUNT_DISTINCT, + aggregate.getAggCallList().get(0).getAggregation()); } @Test From 7da43afb9c1adbbd8bf0105b14d59434d1cce471 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 3 Apr 2025 15:09:41 -0700 Subject: [PATCH 3/4] test: additional tests for name preservation --- .../substrait/isthmus/NameRoundtripTest.java | 17 +++++++++- .../io/substrait/isthmus/PlanTestBase.java | 32 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index 0ba07b39f..f8d6e9897 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -3,13 +3,15 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import io.substrait.plan.Plan; +import io.substrait.relation.NamedScan; import java.util.List; import org.junit.jupiter.api.Test; public class NameRoundtripTest extends PlanTestBase { @Test - void outputNamesShouldBeConsistent() throws Exception { + void preserveNamesFromSql() throws Exception { List creates = List.of("CREATE TABLE foo(a BIGINT, b BIGINT)"); SqlToSubstrait s = new SqlToSubstrait(); @@ -33,4 +35,17 @@ void outputNamesShouldBeConsistent() throws Exception { org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot); assertEquals(expectedNames, calciteRelRoot2.validatedRowType.getFieldNames()); } + + @Test + void preserveNamesFromSubstrait() { + NamedScan rel = + substraitBuilder.namedScan( + List.of("foo"), + List.of("i64", "struct", "struct0", "struct1"), + List.of(R.I64, R.struct(R.FP64, R.STRING))); + + Plan.Root planRoot = + Plan.Root.builder().input(rel).names(List.of("i", "s", "s0", "s1")).build(); + assertFullRoundTrip(planRoot); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 84d1987af..7c7770172 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -178,6 +178,7 @@ protected void assertFullRoundTrip(String sqlQuery, List createStatement * */ protected void assertFullRoundTrip(Rel pojo1) { + // TODO: reuse the Plan.Root based assertFullRoundTrip by generating names var extensionCollector = new ExtensionCollector(); // Substrait POJO 1 -> Substrait Proto @@ -200,6 +201,37 @@ protected void assertFullRoundTrip(Rel pojo1) { assertEquals(pojo1, pojo3); } + /** + * Verifies that the given POJO can be converted: + * + *
    + *
  • From POJO to Proto and back + *
  • From POJO to Calcite and back + *
+ */ + protected void assertFullRoundTrip(Plan.Root pojo1) { + var extensionCollector = new ExtensionCollector(); + + // Substrait POJO 1 -> Substrait Proto + io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1); + + // Substrait Proto -> Substrait Pojo 2 + io.substrait.plan.Plan.Root pojo2 = + new ProtoRelConverter(extensionCollector, EXTENSION_COLLECTION).from(proto); + + // Verify that POJOs are the same + assertEquals(pojo1, pojo2); + + // Substrait POJO 2 -> Calcite + RelRoot calcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo2); + + // Calcite -> Substrait POJO 3 + io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, EXTENSION_COLLECTION); + + // Verify that POJOs are the same + assertEquals(pojo1, pojo3); + } + protected void assertRowMatch(RelDataType actual, Type... expected) { assertRowMatch(actual, Arrays.asList(expected)); } From 0e7ad0354ffe5743f9cf4728e9fe96c8c73ef828 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 3 Apr 2025 15:50:51 -0700 Subject: [PATCH 4/4] refactor: minor SubstraitRelVisitor cleanup --- .../main/java/io/substrait/isthmus/SubstraitRelVisitor.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index f88f62b0c..400e99722 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -398,11 +398,7 @@ public static Plan.Root convert( } public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) { - SubstraitRelVisitor visitor = - new SubstraitRelVisitor( - relNode.getCluster().getTypeFactory(), extensions, FEATURES_DEFAULT); - visitor.popFieldAccessDepthMap(relNode); - return visitor.apply(relNode); + return convert(relNode, extensions, FEATURES_DEFAULT); } public static Rel convert(