From ffc4494f66b1a31ee3b3d192da2ff2491f189d26 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Wed, 25 Jun 2025 10:16:50 +0200 Subject: [PATCH 1/4] fix(isthmus): handle subqueries with outer field references fixes https://github.com/substrait-io/substrait-java/issues/382 Signed-off-by: Niels Pardon --- .../io/substrait/dsl/SubstraitBuilder.java | 20 + .../substrait/expression/FieldReference.java | 8 +- .../isthmus/SubstraitRelNodeConverter.java | 66 +++- .../expression/ExpressionRexConverter.java | 52 ++- .../io/substrait/isthmus/TpcdsQueryTest.java | 5 +- .../io/substrait/isthmus/TpchQueryTest.java | 7 +- .../expression/SubqueryConversionTest.java | 371 ++++++++++++++++++ 7 files changed, 514 insertions(+), 15 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 9de3a0faf..2bd4e8c5f 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -6,6 +6,7 @@ import io.substrait.expression.Expression.FailureBehavior; import io.substrait.expression.Expression.IfClause; import io.substrait.expression.Expression.IfThen; +import io.substrait.expression.Expression.PredicateOp; import io.substrait.expression.Expression.SingleOrList; import io.substrait.expression.Expression.Switch; import io.substrait.expression.Expression.SwitchClause; @@ -644,6 +645,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right); } + public Expression.ScalarFunctionInvocation and(Expression... args) { + // If any arg is nullable, the output of and is potentially nullable + // For example: false and null = null + boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); + Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; + return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "and:bool", outputType, args); + } + public Expression.ScalarFunctionInvocation or(Expression... args) { // If any arg is nullable, the output of or is potentially nullable // For example: false or null = null @@ -706,4 +715,15 @@ public Plan plan(Plan.Root root) { public Rel.Remap remap(Integer... fields) { return Rel.Remap.of(Arrays.asList(fields)); } + + public Expression scalarSubquery(Rel input, Type type) { + return Expression.ScalarSubquery.builder().input(input).type(type).build(); + } + + public Expression exists(Rel project) { + return Expression.SetPredicate.builder() + .tuples(project) + .predicateOp(PredicateOp.PREDICATE_OP_EXISTS) + .build(); + } } diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 192dc2578..722699153 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -36,7 +36,13 @@ public R accept( } public boolean isSimpleRootReference() { - return segments().size() == 1 && !inputExpression().isPresent(); + return segments().size() == 1 + && !inputExpression().isPresent() + && !outerReferenceStepsOut().isPresent(); + } + + public boolean isOuterReference() { + return outerReferenceStepsOut().orElse(0) > 0; } public FieldReference dereferenceStruct(int index) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 1e69e4d6d..e4279794b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -34,9 +34,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.OptionalLong; +import java.util.Stack; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -49,6 +51,7 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.TableModify; import org.apache.calcite.rel.logical.LogicalTableModify; @@ -156,8 +159,11 @@ public static RelNode convert( @Override public RelNode visit(Filter filter, Context context) throws RuntimeException { RelNode input = filter.getInput().accept(this, context); + context.pushParentRelNodes(input); RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); - RelNode node = relBuilder.push(input).filter(filterCondition).build(); + RelNode node = + relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build(); + context.popParentRelNodes(); return applyRemap(node, filter.getRemap()); } @@ -183,6 +189,8 @@ public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeExcepti @Override public RelNode visit(Project project, Context context) throws RuntimeException { RelNode child = project.getInput().accept(this, context); + context.pushParentRelNodes(child); + Stream directOutputs = IntStream.range(0, child.getRowType().getFieldCount()) .mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex)); @@ -193,7 +201,12 @@ public RelNode visit(Project project, Context context) throws RuntimeException { List rexExprs = Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList()); - RelNode node = relBuilder.push(child).project(rexExprs).build(); + RelNode node = + relBuilder + .push(child) + .project(rexExprs, List.of(), false, context.popCorrelationIds()) + .build(); + context.popParentRelNodes(); return applyRemap(node, project.getRemap()); } @@ -211,12 +224,19 @@ public RelNode visit(Cross cross, Context context) throws RuntimeException { public RelNode visit(Join join, Context context) throws RuntimeException { RelNode left = join.getLeft().accept(this, context); RelNode right = join.getRight().accept(this, context); + context.pushParentRelNodes(left, right); RexNode condition = join.getCondition() .map(c -> c.accept(expressionRexConverter, context)) .orElse(relBuilder.literal(true)); JoinRelType joinType = asJoinRelType(join); - RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build(); + RelNode node = + relBuilder + .push(left) + .push(right) + .join(joinType, condition, context.popCorrelationIds()) + .build(); + context.popParentRelNodes(); return applyRemap(node, join.getRemap()); } @@ -627,8 +647,48 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) { } public static class Context implements VisitationContext { + protected final Stack parentRelations = new Stack<>(); + + protected final Stack> correlationIds = new Stack<>(); + + private int subqueryDepth; + public static Context newContext() { return new Context(); } + + public void pushParentRelNodes(final RelNode... inputs) { + parentRelations.push(inputs); + this.correlationIds.push(new HashSet<>()); + } + + public void popParentRelNodes() { + parentRelations.pop(); + } + + public RelNode[] getParentRelation(final Integer stepsOut) { + return this.parentRelations.get(subqueryDepth - stepsOut); + } + + public java.util.Set popCorrelationIds() { + return correlationIds.pop(); + } + + public void addCorrelationId(final int stepsOut, final CorrelationId correlationId) { + final int index = subqueryDepth - stepsOut; + this.correlationIds.get(index).add(correlationId); + } + + public void incrementSubqueryDepth() { + this.subqueryDepth++; + } + + public void decrementSubqueryDepth() { + this.subqueryDepth--; + } + } + + public RelBuilder getRelBuilder() { + return relBuilder; } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index a042db611..ea69fba40 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -1,6 +1,9 @@ package io.substrait.isthmus.expression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Range; +import com.google.common.collect.RangeMap; +import com.google.common.collect.TreeRangeMap; import io.substrait.expression.AbstractExpressionVisitor; import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; @@ -33,6 +36,7 @@ import java.util.stream.Stream; import org.apache.calcite.avatica.util.ByteString; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; @@ -513,7 +517,9 @@ private boolean isDistinct(Expression.WindowFunctionInvocation expr) { public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException { List needles = expr.needles().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); + context.incrementSubqueryDepth(); RelNode rel = expr.haystack().accept(relNodeConverter, context); + context.decrementSubqueryDepth(); return RexSubQuery.in(rel, ImmutableList.copyOf(needles)); } @@ -589,13 +595,47 @@ public RexNode visit(Expression.Cast expr, Context context) throws RuntimeExcept @Override public RexNode visit(FieldReference expr, Context context) throws RuntimeException { if (expr.isSimpleRootReference()) { - ReferenceSegment segment = expr.segments().get(0); + final ReferenceSegment segment = expr.segments().get(0); - RexInputRef rexInputRef; + final RexInputRef rexInputRef; if (segment instanceof FieldReference.StructField) { - FieldReference.StructField f = (FieldReference.StructField) segment; + FieldReference.StructField field = (FieldReference.StructField) segment; rexInputRef = - new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); + new RexInputRef(field.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } + + return rexInputRef; + } else if (expr.isOuterReference()) { + final ReferenceSegment segment = expr.segments().get(0); + + final RexNode rexInputRef; + if (segment instanceof FieldReference.StructField) { + FieldReference.StructField field = (FieldReference.StructField) segment; + + final RelNode[] parents = context.getParentRelation(expr.outerReferenceStepsOut().get()); + final RangeMap fieldRangeMap = TreeRangeMap.create(); + + int begin = 0; + int fieldOffset = field.offset(); + for (final RelNode parent : parents) { + final int end = begin + parent.getRowType().getFieldCount(); + final Range range = Range.closedOpen(begin, end); + fieldRangeMap.put(range, parent); + if (range.contains(field.offset())) { + fieldOffset = fieldOffset - range.lowerEndpoint(); + } + begin = end; + } + + CorrelationId correlationId = relNodeConverter.getRelBuilder().getCluster().createCorrel(); + context.addCorrelationId(expr.outerReferenceStepsOut().get(), correlationId); + rexInputRef = + rexBuilder.makeFieldAccess( + rexBuilder.makeCorrel( + fieldRangeMap.get(field.offset()).getRowType(), correlationId), + fieldOffset); } else { throw new IllegalArgumentException("Unhandled type: " + segment); } @@ -646,13 +686,17 @@ public RexNode visitEnumArg( @Override public RexNode visit(ScalarSubquery expr, Context context) throws RuntimeException { + context.incrementSubqueryDepth(); RelNode inputRelnode = expr.input().accept(relNodeConverter, context); + context.decrementSubqueryDepth(); return RexSubQuery.scalar(inputRelnode); } @Override public RexNode visit(SetPredicate expr, Context context) throws RuntimeException { + context.incrementSubqueryDepth(); RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context); + context.decrementSubqueryDepth(); switch (expr.predicateOp()) { case PREDICATE_OP_EXISTS: return RexSubQuery.exists(inputRelnode); diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java index cbba343fd..64c65ca74 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.proto.Plan; import java.io.IOException; @@ -13,7 +14,7 @@ /** TPC-DS test to convert SQL to Substrait and then convert those plans back to SQL. */ public class TpcdsQueryTest extends PlanTestBase { private static final Set toSubstraitExclusions = Set.of(9, 27, 36, 70, 86); - private static final Set fromSubstraitExclusions = Set.of(6, 8, 67); + private static final Set fromSubstraitExclusions = Set.of(1, 8, 30, 67, 81); static IntStream testCases() { return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n)); @@ -32,6 +33,8 @@ public void testQuery(int query) throws IOException { if (!fromSubstraitExclusions.contains(query)) { assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); + } else { + assertThrows(Throwable.class, () -> toSql(plan), "Substrait to SQL"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java index d52a682cf..0a8b0e8af 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java @@ -4,7 +4,6 @@ import io.substrait.proto.Plan; import java.io.IOException; -import java.util.Set; import java.util.stream.IntStream; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.params.ParameterizedTest; @@ -12,8 +11,6 @@ /** TPC-H test to convert SQL to Substrait and then convert those plans back to SQL. */ public class TpchQueryTest extends PlanTestBase { - private static final Set fromSubstraitExclusions = Set.of(17); - static IntStream testCases() { return IntStream.rangeClosed(1, 22); } @@ -29,9 +26,7 @@ public void testQuery(int query) throws IOException { Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait"); - if (!fromSubstraitExclusions.contains(query)) { - assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); - } + assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); } private Plan toSubstraitPlan(String sql) throws SqlParseException { diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java new file mode 100644 index 000000000..0fc5a4616 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java @@ -0,0 +1,371 @@ +package io.substrait.isthmus.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.expression.FieldReference; +import io.substrait.isthmus.PlanTestBase; +import io.substrait.isthmus.SubstraitToCalcite; +import io.substrait.isthmus.sql.SubstraitSqlDialect; +import io.substrait.relation.Rel; +import io.substrait.relation.Rel.Remap; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.apache.calcite.rel.RelNode; +import org.junit.jupiter.api.Test; + +class SubqueryConversionTest extends PlanTestBase { + protected final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + + @Test + void testOuterFieldReferenceOneStep() { + /* + * SELECT + * orders.o_orderkey, + * (SELECT customer.c_nationkey FROM customer WHERE customer.c_custkey = orders.o_custkey) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + // orders.o_orderkey + substraitBuilder.fieldReference(input, 0), + // (SELECT customer.c_nationkey FROM customer WHERE customer.c_custkey = + // orders.o_custkey) + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input2 -> + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference(input2, 0), + // orders.o_custkey + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 1)), + substraitBuilder.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)))), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2, 3)), + substraitBuilder.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + + final RelNode calciteRel = converter.convert(root); + + // LogicalFilter has field reference with $cor0 correlation variable + // outer LogicalProject has variablesSet containing $cor0 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\"))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } + + @Test + void testOuterFieldReferenceTwoSteps() { + /* + * SELECT + * orders.o_orderkey, + * ( + * SELECT + * n_name + * FROM nation + * WHERE n_nationkey = + * ( + * SELECT + * customer.c_nationkey + * FROM customer + * WHERE + * customer.c_custkey = orders.o_custkey + * ) + * ) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.fieldReference(input, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(2)), + substraitBuilder.filter( + input2 -> + substraitBuilder.equal( + substraitBuilder.fieldReference(input2, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input3 -> + List.of( + substraitBuilder.fieldReference(input3, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input3 -> + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference( + input3, 0), + // orders.o_custkey + FieldReference + .newRootStructOuterReference( + 1, + TypeCreator.REQUIRED.I64, + 2)), + substraitBuilder.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of( + TypeCreator.REQUIRED.I64, + TypeCreator.REQUIRED.I64)))), + TypeCreator.NULLABLE.I64)), + substraitBuilder.namedScan( + List.of("nation"), + List.of("n_nationkey", "n_name"), + List.of( + TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)))), + TypeCreator.NULLABLE.STRING)), + Remap.of(List.of(2, 3)), + substraitBuilder.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + + final RelNode calciteRel = converter.convert(root); + + // most inner LogicalFilter has field reference with $cor0 correlation variable + // most outer LogicalProject has variablesSet containing $cor0 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(n_name0=[$1])\n" + + " LogicalFilter(condition=[=($0, $SCALAR_QUERY({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "}))])\n" + + " LogicalTableScan(table=[[nation]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"n_name\" AS \"n_name0\"\n" + + "FROM \"nation\"\n" + + "WHERE \"n_nationkey\" = (((SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\")))))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } + + @Test + void testInPredicateOuterFieldReference() { + /* + * SELECT + * orders.o_orderkey, + * ( + * SELECT + * n_name + * FROM nation + * WHERE n_nationkey IN + * ( + * SELECT + * customer.c_nationkey + * FROM customer + * WHERE + * customer.c_custkey = orders.o_custkey + * ) + * ) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.fieldReference(input, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(2)), + substraitBuilder.filter( + input2 -> + substraitBuilder.inPredicate( + substraitBuilder.project( + input3 -> + List.of(substraitBuilder.fieldReference(input3, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input3 -> + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference(input3, 0), + // orders.o_custkey + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 2)), + substraitBuilder.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of( + TypeCreator.REQUIRED.I64, + TypeCreator.REQUIRED.I64)))), + substraitBuilder.fieldReference(input2, 0)), + substraitBuilder.namedScan( + List.of("nation"), + List.of("n_nationkey", "n_name"), + List.of( + TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)))), + TypeCreator.NULLABLE.STRING)), + Remap.of(List.of(2, 3)), + substraitBuilder.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + + final RelNode calciteRel = converter.convert(root); + + // most inner LogicalFilter has field reference with $cor0 correlation variable + // most outer LogicalProject has variablesSet containing $cor0 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(n_name0=[$1])\n" + + " LogicalFilter(condition=[IN($0, {\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "})])\n" + + " LogicalTableScan(table=[[nation]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"n_name\" AS \"n_name0\"\n" + + "FROM \"nation\"\n" + + "WHERE \"n_nationkey\" IN (SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\")))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } + + @Test + void testSetPredicateOuterFieldReference() { + /* + * SELECT + * orders.o_orderkey, + * ( + * SELECT + * n_name + * FROM nation + * WHERE EXISTS + * ( + * SELECT + * customer.c_nationkey + * FROM customer + * WHERE + * customer.c_custkey = orders.o_custkey + * AND customer.c_nationkey = nation.n_nationkey + * ) + * ) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.fieldReference(input, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(2)), + substraitBuilder.filter( + input2 -> + substraitBuilder.exists( + substraitBuilder.project( + input3 -> + List.of(substraitBuilder.fieldReference(input3, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input3 -> + substraitBuilder.and( + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference( + input3, 0), + // orders.o_custkey + FieldReference + .newRootStructOuterReference( + 1, + TypeCreator.REQUIRED.I64, + 2)), + substraitBuilder.equal( + // customer.c_nationkey + substraitBuilder.fieldReference( + input3, 1), + // nation.n_nationkey + FieldReference + .newRootStructOuterReference( + 0, + TypeCreator.REQUIRED.I64, + 1))), + substraitBuilder.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of( + TypeCreator.REQUIRED.I64, + TypeCreator.REQUIRED.I64))))), + substraitBuilder.namedScan( + List.of("nation"), + List.of("n_nationkey", "n_name"), + List.of( + TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)))), + TypeCreator.NULLABLE.STRING)), + Remap.of(List.of(2, 3)), + substraitBuilder.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + + final RelNode calciteRel = converter.convert(root); + + // most inner LogicalFilter has field references with $cor0 and $cor1 correlation variables + // most outer LogicalProject has variablesSet containing $cor0 correlation variable + // most outer LogicalFilter has variablesSet containing $cor1 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(n_name0=[$1])\n" + + " LogicalFilter(condition=[EXISTS({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[AND(=($0, $cor0.o_custkey), =($1, $cor1.n_nationkey))])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "})], variablesSet=[[$cor1]])\n" + + " LogicalTableScan(table=[[nation]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"n_name\" AS \"n_name0\"\n" + + "FROM \"nation\"\n" + + "WHERE EXISTS (SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\" AND \"c_nationkey\" = \"nation\".\"n_nationkey\")))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } +} From aa4e998de9082cc932920ee95dd52d1264ba33bb Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Mon, 28 Jul 2025 10:33:57 +0200 Subject: [PATCH 2/4] chore(isthmus): apply feedback Signed-off-by: Niels Pardon --- .../io/substrait/dsl/SubstraitBuilder.java | 4 +- .../isthmus/SubstraitRelNodeConverter.java | 81 +++++++++++++++--- .../expression/ExpressionRexConverter.java | 31 +++---- .../expression/SubqueryConversionTest.java | 83 +++++++------------ 4 files changed, 110 insertions(+), 89 deletions(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 2bd4e8c5f..4e8e428f7 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -720,9 +720,9 @@ public Expression scalarSubquery(Rel input, Type type) { return Expression.ScalarSubquery.builder().input(input).type(type).build(); } - public Expression exists(Rel project) { + public Expression exists(Rel rel) { return Expression.SetPredicate.builder() - .tuples(project) + .tuples(rel) .predicateOp(PredicateOp.PREDICATE_OP_EXISTS) .build(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index e4279794b..90c20f11d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -3,6 +3,9 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Range; +import com.google.common.collect.RangeMap; +import com.google.common.collect.TreeRangeMap; import io.substrait.expression.Expression; import io.substrait.expression.Expression.SortDirection; import io.substrait.expression.FunctionArg; @@ -159,7 +162,7 @@ public static RelNode convert( @Override public RelNode visit(Filter filter, Context context) throws RuntimeException { RelNode input = filter.getInput().accept(this, context); - context.pushParentRelNodes(input); + context.pushOuterRowType(input.getRowType()); RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); RelNode node = relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build(); @@ -189,7 +192,7 @@ public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeExcepti @Override public RelNode visit(Project project, Context context) throws RuntimeException { RelNode child = project.getInput().accept(this, context); - context.pushParentRelNodes(child); + context.pushOuterRowType(child.getRowType()); Stream directOutputs = IntStream.range(0, child.getRowType().getFieldCount()) @@ -224,7 +227,7 @@ public RelNode visit(Cross cross, Context context) throws RuntimeException { public RelNode visit(Join join, Context context) throws RuntimeException { RelNode left = join.getLeft().accept(this, context); RelNode right = join.getRight().accept(this, context); - context.pushParentRelNodes(left, right); + context.pushOuterRowType(left.getRowType(), right.getRowType()); RexNode condition = join.getCondition() .map(c -> c.accept(expressionRexConverter, context)) @@ -646,48 +649,100 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) { return relBuilder.push(relNode).project(rexList).build(); } + /** A shared context for the Substrait to RelNode conversion. */ public static class Context implements VisitationContext { - protected final Stack parentRelations = new Stack<>(); + protected final Stack> outerRowTypes = new Stack<>(); protected final Stack> correlationIds = new Stack<>(); private int subqueryDepth; + /** + * Creates a new {@link Context} instance. + * + * @return the new {@link Context} instance + */ public static Context newContext() { return new Context(); } - public void pushParentRelNodes(final RelNode... inputs) { - parentRelations.push(inputs); + /** + * Adds the outer row types to the top of the stack of outer row types. + * + *

Row types are stored as a {@link RangeMap} with field indices as keys and the {@link + * RelDataType} row type containing the field at the field index by continuously numbering the + * field indices from 0 across all provided row types in the order the row types are passed as + * arguments. + * + * @param inputs the row types to add + */ + public void pushOuterRowType(final RelDataType... inputs) { + final RangeMap fieldRangeMap = TreeRangeMap.create(); + int begin = 0; + for (final RelDataType parent : inputs) { + final int end = begin + parent.getFieldCount(); + final Range range = Range.closedOpen(begin, end); + fieldRangeMap.put(range, parent); + begin = end; + } + + outerRowTypes.push(fieldRangeMap); this.correlationIds.push(new HashSet<>()); } public void popParentRelNodes() { - parentRelations.pop(); - } - - public RelNode[] getParentRelation(final Integer stepsOut) { - return this.parentRelations.get(subqueryDepth - stepsOut); - } - + outerRowTypes.pop(); + } + + /** + * Returns the outer row type {@link RangeMap} walking up the given steps from the current + * subquery depth. + * + * @param stepsOut number of steps to walk up from the current subquery depth + * @return {@link RangeMap} with field indices as keys and the {@link RelDataType} row type + * containing the field at the field index + */ + public RangeMap getOuterRowTypeRangeMap(final Integer stepsOut) { + return this.outerRowTypes.get(subqueryDepth - stepsOut); + } + + /** + * Removes the correlation ids at the top of the stack. + * + * @return the correlation ids removed from the top of the stack + */ public java.util.Set popCorrelationIds() { return correlationIds.pop(); } + /** + * Adds a {@link CorrelationId} to the subquery depth walking up the given steps from the + * current subquery depth. + * + * @param stepsOut number of steps to walk up from the current subquery depth + * @param correlationId the {@link CorrelationId} to add + */ public void addCorrelationId(final int stepsOut, final CorrelationId correlationId) { final int index = subqueryDepth - stepsOut; this.correlationIds.get(index).add(correlationId); } + /** Increments the current subquery depth. */ public void incrementSubqueryDepth() { this.subqueryDepth++; } + /** Decrements the current subquery depth. */ public void decrementSubqueryDepth() { this.subqueryDepth--; } } + /** + * Returns the {@link RelBuilder} of this converter. + * + * @return the {@link RelBuilder} + */ public RelBuilder getRelBuilder() { return relBuilder; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index ea69fba40..2b8052889 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -3,7 +3,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Range; import com.google.common.collect.RangeMap; -import com.google.common.collect.TreeRangeMap; import io.substrait.expression.AbstractExpressionVisitor; import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; @@ -599,7 +598,7 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti final RexInputRef rexInputRef; if (segment instanceof FieldReference.StructField) { - FieldReference.StructField field = (FieldReference.StructField) segment; + final FieldReference.StructField field = (FieldReference.StructField) segment; rexInputRef = new RexInputRef(field.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); } else { @@ -612,29 +611,19 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti final RexNode rexInputRef; if (segment instanceof FieldReference.StructField) { - FieldReference.StructField field = (FieldReference.StructField) segment; - - final RelNode[] parents = context.getParentRelation(expr.outerReferenceStepsOut().get()); - final RangeMap fieldRangeMap = TreeRangeMap.create(); - - int begin = 0; - int fieldOffset = field.offset(); - for (final RelNode parent : parents) { - final int end = begin + parent.getRowType().getFieldCount(); - final Range range = Range.closedOpen(begin, end); - fieldRangeMap.put(range, parent); - if (range.contains(field.offset())) { - fieldOffset = fieldOffset - range.lowerEndpoint(); - } - begin = end; - } + final FieldReference.StructField field = (FieldReference.StructField) segment; + + final RangeMap fieldRangeMap = + context.getOuterRowTypeRangeMap(expr.outerReferenceStepsOut().get()); + final Range range = fieldRangeMap.getEntry(field.offset()).getKey(); + final int fieldOffset = field.offset() - range.lowerEndpoint(); - CorrelationId correlationId = relNodeConverter.getRelBuilder().getCluster().createCorrel(); + final CorrelationId correlationId = + relNodeConverter.getRelBuilder().getCluster().createCorrel(); context.addCorrelationId(expr.outerReferenceStepsOut().get(), correlationId); rexInputRef = rexBuilder.makeFieldAccess( - rexBuilder.makeCorrel( - fieldRangeMap.get(field.offset()).getRowType(), correlationId), + rexBuilder.makeCorrel(fieldRangeMap.get(field.offset()), correlationId), fieldOffset); } else { throw new IllegalArgumentException("Unhandled type: " + segment); diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java index 0fc5a4616..ea82e0a8d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java @@ -6,6 +6,7 @@ import io.substrait.isthmus.PlanTestBase; import io.substrait.isthmus.SubstraitToCalcite; import io.substrait.isthmus.sql.SubstraitSqlDialect; +import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; import io.substrait.relation.Rel.Remap; import io.substrait.type.TypeCreator; @@ -16,6 +17,24 @@ class SubqueryConversionTest extends PlanTestBase { protected final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + private final Rel customerTableScan = + substraitBuilder.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); + + private final NamedScan orderTableScan = + substraitBuilder.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); + + private final NamedScan nationTableScan = + substraitBuilder.namedScan( + List.of("nation"), + List.of("n_nationkey", "n_name"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)); + @Test void testOuterFieldReferenceOneStep() { /* @@ -44,16 +63,10 @@ void testOuterFieldReferenceOneStep() { // orders.o_custkey FieldReference.newRootStructOuterReference( 1, TypeCreator.REQUIRED.I64, 1)), - substraitBuilder.namedScan( - List.of("customer"), - List.of("c_custkey", "c_nationkey"), - List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)))), + customerTableScan)), TypeCreator.NULLABLE.I64)), Remap.of(List.of(2, 3)), - substraitBuilder.namedScan( - List.of("orders"), - List.of("o_orderkey", "o_custkey"), - List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + orderTableScan); final RelNode calciteRel = converter.convert(root); @@ -127,24 +140,12 @@ void testOuterFieldReferenceTwoSteps() { 1, TypeCreator.REQUIRED.I64, 2)), - substraitBuilder.namedScan( - List.of("customer"), - List.of("c_custkey", "c_nationkey"), - List.of( - TypeCreator.REQUIRED.I64, - TypeCreator.REQUIRED.I64)))), + customerTableScan)), TypeCreator.NULLABLE.I64)), - substraitBuilder.namedScan( - List.of("nation"), - List.of("n_nationkey", "n_name"), - List.of( - TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)))), + nationTableScan)), TypeCreator.NULLABLE.STRING)), Remap.of(List.of(2, 3)), - substraitBuilder.namedScan( - List.of("orders"), - List.of("o_orderkey", "o_custkey"), - List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + orderTableScan); final RelNode calciteRel = converter.convert(root); @@ -217,24 +218,12 @@ void testInPredicateOuterFieldReference() { // orders.o_custkey FieldReference.newRootStructOuterReference( 1, TypeCreator.REQUIRED.I64, 2)), - substraitBuilder.namedScan( - List.of("customer"), - List.of("c_custkey", "c_nationkey"), - List.of( - TypeCreator.REQUIRED.I64, - TypeCreator.REQUIRED.I64)))), + customerTableScan)), substraitBuilder.fieldReference(input2, 0)), - substraitBuilder.namedScan( - List.of("nation"), - List.of("n_nationkey", "n_name"), - List.of( - TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)))), + nationTableScan)), TypeCreator.NULLABLE.STRING)), Remap.of(List.of(2, 3)), - substraitBuilder.namedScan( - List.of("orders"), - List.of("o_orderkey", "o_custkey"), - List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + orderTableScan); final RelNode calciteRel = converter.convert(root); @@ -323,23 +312,11 @@ void testSetPredicateOuterFieldReference() { 0, TypeCreator.REQUIRED.I64, 1))), - substraitBuilder.namedScan( - List.of("customer"), - List.of("c_custkey", "c_nationkey"), - List.of( - TypeCreator.REQUIRED.I64, - TypeCreator.REQUIRED.I64))))), - substraitBuilder.namedScan( - List.of("nation"), - List.of("n_nationkey", "n_name"), - List.of( - TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)))), + customerTableScan))), + nationTableScan)), TypeCreator.NULLABLE.STRING)), Remap.of(List.of(2, 3)), - substraitBuilder.namedScan( - List.of("orders"), - List.of("o_orderkey", "o_custkey"), - List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64))); + orderTableScan); final RelNode calciteRel = converter.convert(root); From e1f55732b17da5b1baa51778ccbed21535f79734 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Fri, 1 Aug 2025 06:14:03 +0200 Subject: [PATCH 3/4] chore: rename popParentRelNodes to popOuterRowType Signed-off-by: Niels Pardon --- .../io/substrait/isthmus/SubstraitRelNodeConverter.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 90c20f11d..801110c5f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -166,7 +166,7 @@ public RelNode visit(Filter filter, Context context) throws RuntimeException { RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); RelNode node = relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build(); - context.popParentRelNodes(); + context.popOuterRowType(); return applyRemap(node, filter.getRemap()); } @@ -209,7 +209,7 @@ public RelNode visit(Project project, Context context) throws RuntimeException { .push(child) .project(rexExprs, List.of(), false, context.popCorrelationIds()) .build(); - context.popParentRelNodes(); + context.popOuterRowType(); return applyRemap(node, project.getRemap()); } @@ -239,7 +239,7 @@ public RelNode visit(Join join, Context context) throws RuntimeException { .push(right) .join(joinType, condition, context.popCorrelationIds()) .build(); - context.popParentRelNodes(); + context.popOuterRowType(); return applyRemap(node, join.getRemap()); } @@ -690,7 +690,7 @@ public void pushOuterRowType(final RelDataType... inputs) { this.correlationIds.push(new HashSet<>()); } - public void popParentRelNodes() { + public void popOuterRowType() { outerRowTypes.pop(); } From bc5f81c48eefa14188d932610bca850fe8cde1c5 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Fri, 1 Aug 2025 06:26:45 +0200 Subject: [PATCH 4/4] chore: include TPC-DS query 8 Signed-off-by: Niels Pardon --- isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java index 64c65ca74..9be238bc3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java @@ -14,7 +14,7 @@ /** TPC-DS test to convert SQL to Substrait and then convert those plans back to SQL. */ public class TpcdsQueryTest extends PlanTestBase { private static final Set toSubstraitExclusions = Set.of(9, 27, 36, 70, 86); - private static final Set fromSubstraitExclusions = Set.of(1, 8, 30, 67, 81); + private static final Set fromSubstraitExclusions = Set.of(1, 30, 67, 81); static IntStream testCases() { return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));