diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 9de3a0faf..4e8e428f7 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -6,6 +6,7 @@ import io.substrait.expression.Expression.FailureBehavior; import io.substrait.expression.Expression.IfClause; import io.substrait.expression.Expression.IfThen; +import io.substrait.expression.Expression.PredicateOp; import io.substrait.expression.Expression.SingleOrList; import io.substrait.expression.Expression.Switch; import io.substrait.expression.Expression.SwitchClause; @@ -644,6 +645,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right); } + public Expression.ScalarFunctionInvocation and(Expression... args) { + // If any arg is nullable, the output of and is potentially nullable + // For example: true and null = null + boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); + Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; + return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "and:bool", outputType, args); + } + public Expression.ScalarFunctionInvocation or(Expression... args) { // If any arg is nullable, the output of or is potentially nullable // For example: false or null = null @@ -706,4 +715,15 @@ public Plan plan(Plan.Root root) { public Rel.Remap remap(Integer... 
fields) { return Rel.Remap.of(Arrays.asList(fields)); } + + public Expression scalarSubquery(Rel input, Type type) { + return Expression.ScalarSubquery.builder().input(input).type(type).build(); + } + + public Expression exists(Rel rel) { + return Expression.SetPredicate.builder() + .tuples(rel) + .predicateOp(PredicateOp.PREDICATE_OP_EXISTS) + .build(); + } } diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 192dc2578..722699153 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -36,7 +36,13 @@ public R accept( } public boolean isSimpleRootReference() { - return segments().size() == 1 && !inputExpression().isPresent(); + return segments().size() == 1 + && !inputExpression().isPresent() + && !outerReferenceStepsOut().isPresent(); + } + + public boolean isOuterReference() { + return outerReferenceStepsOut().orElse(0) > 0; } public FieldReference dereferenceStruct(int index) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 1e69e4d6d..801110c5f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -3,6 +3,9 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Range; +import com.google.common.collect.RangeMap; +import com.google.common.collect.TreeRangeMap; import io.substrait.expression.Expression; import io.substrait.expression.Expression.SortDirection; import io.substrait.expression.FunctionArg; @@ -34,9 +37,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import 
java.util.List; import java.util.Optional; import java.util.OptionalLong; +import java.util.Stack; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -49,6 +54,7 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.TableModify; import org.apache.calcite.rel.logical.LogicalTableModify; @@ -156,8 +162,11 @@ public static RelNode convert( @Override public RelNode visit(Filter filter, Context context) throws RuntimeException { RelNode input = filter.getInput().accept(this, context); + context.pushOuterRowType(input.getRowType()); RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); - RelNode node = relBuilder.push(input).filter(filterCondition).build(); + RelNode node = + relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build(); + context.popOuterRowType(); return applyRemap(node, filter.getRemap()); } @@ -183,6 +192,8 @@ public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeExcepti @Override public RelNode visit(Project project, Context context) throws RuntimeException { RelNode child = project.getInput().accept(this, context); + context.pushOuterRowType(child.getRowType()); + Stream directOutputs = IntStream.range(0, child.getRowType().getFieldCount()) .mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex)); @@ -193,7 +204,12 @@ public RelNode visit(Project project, Context context) throws RuntimeException { List rexExprs = Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList()); - RelNode node = relBuilder.push(child).project(rexExprs).build(); + RelNode node = + relBuilder + .push(child) + .project(rexExprs, List.of(), false, context.popCorrelationIds()) + .build(); + 
context.popOuterRowType(); return applyRemap(node, project.getRemap()); } @@ -211,12 +227,19 @@ public RelNode visit(Cross cross, Context context) throws RuntimeException { public RelNode visit(Join join, Context context) throws RuntimeException { RelNode left = join.getLeft().accept(this, context); RelNode right = join.getRight().accept(this, context); + context.pushOuterRowType(left.getRowType(), right.getRowType()); RexNode condition = join.getCondition() .map(c -> c.accept(expressionRexConverter, context)) .orElse(relBuilder.literal(true)); JoinRelType joinType = asJoinRelType(join); - RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build(); + RelNode node = + relBuilder + .push(left) + .push(right) + .join(joinType, condition, context.popCorrelationIds()) + .build(); + context.popOuterRowType(); return applyRemap(node, join.getRemap()); } @@ -626,9 +649,101 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) { return relBuilder.push(relNode).project(rexList).build(); } + /** A shared context for the Substrait to RelNode conversion. */ public static class Context implements VisitationContext { + protected final Stack> outerRowTypes = new Stack<>(); + + protected final Stack> correlationIds = new Stack<>(); + + private int subqueryDepth; + + /** + * Creates a new {@link Context} instance. + * + * @return the new {@link Context} instance + */ public static Context newContext() { return new Context(); } + + /** + * Adds the outer row types to the top of the stack of outer row types. + * + *

<p>Row types are stored as a {@link RangeMap} with field indices as keys and the {@link + * RelDataType} row type containing the field at the field index by continuously numbering the + * field indices from 0 across all provided row types in the order the row types are passed as + * arguments. + * + * @param inputs the row types to add + */ + public void pushOuterRowType(final RelDataType... inputs) { + final RangeMap fieldRangeMap = TreeRangeMap.create(); + int begin = 0; + for (final RelDataType parent : inputs) { + final int end = begin + parent.getFieldCount(); + final Range range = Range.closedOpen(begin, end); + fieldRangeMap.put(range, parent); + begin = end; + } + + outerRowTypes.push(fieldRangeMap); + this.correlationIds.push(new HashSet<>()); + } + + public void popOuterRowType() { + outerRowTypes.pop(); + } + + /** + * Returns the outer row type {@link RangeMap} walking up the given steps from the current + * subquery depth. + * + * @param stepsOut number of steps to walk up from the current subquery depth + * @return {@link RangeMap} with field indices as keys and the {@link RelDataType} row type + * containing the field at the field index + */ + public RangeMap getOuterRowTypeRangeMap(final Integer stepsOut) { + return this.outerRowTypes.get(subqueryDepth - stepsOut); + } + + /** + * Removes the correlation ids at the top of the stack. + * + * @return the correlation ids removed from the top of the stack + */ + public java.util.Set popCorrelationIds() { + return correlationIds.pop(); + } + + /** + * Adds a {@link CorrelationId} to the subquery depth walking up the given steps from the + * current subquery depth. 
+ * + * @param stepsOut number of steps to walk up from the current subquery depth + * @param correlationId the {@link CorrelationId} to add + */ + public void addCorrelationId(final int stepsOut, final CorrelationId correlationId) { + final int index = subqueryDepth - stepsOut; + this.correlationIds.get(index).add(correlationId); + } + + /** Increments the current subquery depth. */ + public void incrementSubqueryDepth() { + this.subqueryDepth++; + } + + /** Decrements the current subquery depth. */ + public void decrementSubqueryDepth() { + this.subqueryDepth--; + } + } + + /** + * Returns the {@link RelBuilder} of this converter. + * + * @return the {@link RelBuilder} + */ + public RelBuilder getRelBuilder() { + return relBuilder; } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index a042db611..2b8052889 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -1,6 +1,8 @@ package io.substrait.isthmus.expression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Range; +import com.google.common.collect.RangeMap; import io.substrait.expression.AbstractExpressionVisitor; import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; @@ -33,6 +35,7 @@ import java.util.stream.Stream; import org.apache.calcite.avatica.util.ByteString; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; @@ -513,7 +516,9 @@ private boolean isDistinct(Expression.WindowFunctionInvocation expr) { public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException { List 
needles = expr.needles().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); + context.incrementSubqueryDepth(); RelNode rel = expr.haystack().accept(relNodeConverter, context); + context.decrementSubqueryDepth(); return RexSubQuery.in(rel, ImmutableList.copyOf(needles)); } @@ -589,13 +594,37 @@ public RexNode visit(Expression.Cast expr, Context context) throws RuntimeExcept @Override public RexNode visit(FieldReference expr, Context context) throws RuntimeException { if (expr.isSimpleRootReference()) { - ReferenceSegment segment = expr.segments().get(0); + final ReferenceSegment segment = expr.segments().get(0); - RexInputRef rexInputRef; + final RexInputRef rexInputRef; if (segment instanceof FieldReference.StructField) { - FieldReference.StructField f = (FieldReference.StructField) segment; + final FieldReference.StructField field = (FieldReference.StructField) segment; rexInputRef = - new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); + new RexInputRef(field.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } + + return rexInputRef; + } else if (expr.isOuterReference()) { + final ReferenceSegment segment = expr.segments().get(0); + + final RexNode rexInputRef; + if (segment instanceof FieldReference.StructField) { + final FieldReference.StructField field = (FieldReference.StructField) segment; + + final RangeMap fieldRangeMap = + context.getOuterRowTypeRangeMap(expr.outerReferenceStepsOut().get()); + final Range range = fieldRangeMap.getEntry(field.offset()).getKey(); + final int fieldOffset = field.offset() - range.lowerEndpoint(); + + final CorrelationId correlationId = + relNodeConverter.getRelBuilder().getCluster().createCorrel(); + context.addCorrelationId(expr.outerReferenceStepsOut().get(), correlationId); + rexInputRef = + rexBuilder.makeFieldAccess( + 
rexBuilder.makeCorrel(fieldRangeMap.get(field.offset()), correlationId), + fieldOffset); } else { throw new IllegalArgumentException("Unhandled type: " + segment); } @@ -646,13 +675,17 @@ public RexNode visitEnumArg( @Override public RexNode visit(ScalarSubquery expr, Context context) throws RuntimeException { + context.incrementSubqueryDepth(); RelNode inputRelnode = expr.input().accept(relNodeConverter, context); + context.decrementSubqueryDepth(); return RexSubQuery.scalar(inputRelnode); } @Override public RexNode visit(SetPredicate expr, Context context) throws RuntimeException { + context.incrementSubqueryDepth(); RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context); + context.decrementSubqueryDepth(); switch (expr.predicateOp()) { case PREDICATE_OP_EXISTS: return RexSubQuery.exists(inputRelnode); diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java index cbba343fd..9be238bc3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.proto.Plan; import java.io.IOException; @@ -13,7 +14,7 @@ /** TPC-DS test to convert SQL to Substrait and then convert those plans back to SQL. 
*/ public class TpcdsQueryTest extends PlanTestBase { private static final Set toSubstraitExclusions = Set.of(9, 27, 36, 70, 86); - private static final Set fromSubstraitExclusions = Set.of(6, 8, 67); + private static final Set fromSubstraitExclusions = Set.of(1, 30, 67, 81); static IntStream testCases() { return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n)); @@ -32,6 +33,8 @@ public void testQuery(int query) throws IOException { if (!fromSubstraitExclusions.contains(query)) { assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); + } else { + assertThrows(Throwable.class, () -> toSql(plan), "Substrait to SQL"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java index d52a682cf..0a8b0e8af 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java @@ -4,7 +4,6 @@ import io.substrait.proto.Plan; import java.io.IOException; -import java.util.Set; import java.util.stream.IntStream; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.params.ParameterizedTest; @@ -12,8 +11,6 @@ /** TPC-H test to convert SQL to Substrait and then convert those plans back to SQL. 
*/ public class TpchQueryTest extends PlanTestBase { - private static final Set fromSubstraitExclusions = Set.of(17); - static IntStream testCases() { return IntStream.rangeClosed(1, 22); } @@ -29,9 +26,7 @@ public void testQuery(int query) throws IOException { Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait"); - if (!fromSubstraitExclusions.contains(query)) { - assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); - } + assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); } private Plan toSubstraitPlan(String sql) throws SqlParseException { diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java new file mode 100644 index 000000000..ea82e0a8d --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java @@ -0,0 +1,348 @@ +package io.substrait.isthmus.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.expression.FieldReference; +import io.substrait.isthmus.PlanTestBase; +import io.substrait.isthmus.SubstraitToCalcite; +import io.substrait.isthmus.sql.SubstraitSqlDialect; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Rel; +import io.substrait.relation.Rel.Remap; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.apache.calcite.rel.RelNode; +import org.junit.jupiter.api.Test; + +class SubqueryConversionTest extends PlanTestBase { + protected final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + + private final Rel customerTableScan = + substraitBuilder.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); + + private final NamedScan orderTableScan = + substraitBuilder.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + 
List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); + + private final NamedScan nationTableScan = + substraitBuilder.namedScan( + List.of("nation"), + List.of("n_nationkey", "n_name"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)); + + @Test + void testOuterFieldReferenceOneStep() { + /* + * SELECT + * orders.o_orderkey, + * (SELECT customer.c_nationkey FROM customer WHERE customer.c_custkey = orders.o_custkey) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + // orders.o_orderkey + substraitBuilder.fieldReference(input, 0), + // (SELECT customer.c_nationkey FROM customer WHERE customer.c_custkey = + // orders.o_custkey) + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input2 -> + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference(input2, 0), + // orders.o_custkey + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 1)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2, 3)), + orderTableScan); + + final RelNode calciteRel = converter.convert(root); + + // LogicalFilter has field reference with $cor0 correlation variable + // outer LogicalProject has variablesSet containing $cor0 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\"))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } + + 
@Test + void testOuterFieldReferenceTwoSteps() { + /* + * SELECT + * orders.o_orderkey, + * ( + * SELECT + * n_name + * FROM nation + * WHERE n_nationkey = + * ( + * SELECT + * customer.c_nationkey + * FROM customer + * WHERE + * customer.c_custkey = orders.o_custkey + * ) + * ) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.fieldReference(input, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(2)), + substraitBuilder.filter( + input2 -> + substraitBuilder.equal( + substraitBuilder.fieldReference(input2, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input3 -> + List.of( + substraitBuilder.fieldReference(input3, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input3 -> + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference( + input3, 0), + // orders.o_custkey + FieldReference + .newRootStructOuterReference( + 1, + TypeCreator.REQUIRED.I64, + 2)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + nationTableScan)), + TypeCreator.NULLABLE.STRING)), + Remap.of(List.of(2, 3)), + orderTableScan); + + final RelNode calciteRel = converter.convert(root); + + // most inner LogicalFilter has field reference with $cor0 correlation variable + // most outer LogicalProject has variablesSet containing $cor0 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(n_name0=[$1])\n" + + " LogicalFilter(condition=[=($0, $SCALAR_QUERY({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "}))])\n" + + " LogicalTableScan(table=[[nation]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS 
\"o_orderkey0\", (((SELECT \"n_name\" AS \"n_name0\"\n" + + "FROM \"nation\"\n" + + "WHERE \"n_nationkey\" = (((SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\")))))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } + + @Test + void testInPredicateOuterFieldReference() { + /* + * SELECT + * orders.o_orderkey, + * ( + * SELECT + * n_name + * FROM nation + * WHERE n_nationkey IN + * ( + * SELECT + * customer.c_nationkey + * FROM customer + * WHERE + * customer.c_custkey = orders.o_custkey + * ) + * ) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.fieldReference(input, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(2)), + substraitBuilder.filter( + input2 -> + substraitBuilder.inPredicate( + substraitBuilder.project( + input3 -> + List.of(substraitBuilder.fieldReference(input3, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input3 -> + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference(input3, 0), + // orders.o_custkey + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 2)), + customerTableScan)), + substraitBuilder.fieldReference(input2, 0)), + nationTableScan)), + TypeCreator.NULLABLE.STRING)), + Remap.of(List.of(2, 3)), + orderTableScan); + + final RelNode calciteRel = converter.convert(root); + + // most inner LogicalFilter has field reference with $cor0 correlation variable + // most outer LogicalProject has variablesSet containing $cor0 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(n_name0=[$1])\n" + + " LogicalFilter(condition=[IN($0, {\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " 
LogicalTableScan(table=[[customer]])\n" + + "})])\n" + + " LogicalTableScan(table=[[nation]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"n_name\" AS \"n_name0\"\n" + + "FROM \"nation\"\n" + + "WHERE \"n_nationkey\" IN (SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\")))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } + + @Test + void testSetPredicateOuterFieldReference() { + /* + * SELECT + * orders.o_orderkey, + * ( + * SELECT + * n_name + * FROM nation + * WHERE EXISTS + * ( + * SELECT + * customer.c_nationkey + * FROM customer + * WHERE + * customer.c_custkey = orders.o_custkey + * AND customer.c_nationkey = nation.n_nationkey + * ) + * ) + * FROM orders + */ + final Rel root = + substraitBuilder.project( + input -> + List.of( + substraitBuilder.fieldReference(input, 0), + substraitBuilder.scalarSubquery( + substraitBuilder.project( + input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + Remap.of(List.of(2)), + substraitBuilder.filter( + input2 -> + substraitBuilder.exists( + substraitBuilder.project( + input3 -> + List.of(substraitBuilder.fieldReference(input3, 1)), + Remap.of(List.of(1)), + substraitBuilder.filter( + input3 -> + substraitBuilder.and( + substraitBuilder.equal( + // customer.c_custkey + substraitBuilder.fieldReference( + input3, 0), + // orders.o_custkey + FieldReference + .newRootStructOuterReference( + 1, + TypeCreator.REQUIRED.I64, + 2)), + substraitBuilder.equal( + // customer.c_nationkey + substraitBuilder.fieldReference( + input3, 1), + // nation.n_nationkey + FieldReference + .newRootStructOuterReference( + 0, + TypeCreator.REQUIRED.I64, + 1))), + customerTableScan))), + nationTableScan)), + TypeCreator.NULLABLE.STRING)), + Remap.of(List.of(2, 3)), + orderTableScan); + + final RelNode calciteRel = 
converter.convert(root); + + // most inner LogicalFilter has field references with $cor0 and $cor1 correlation variables + // most outer LogicalProject has variablesSet containing $cor0 correlation variable + // most outer LogicalFilter has variablesSet containing $cor1 correlation variable + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(n_name0=[$1])\n" + + " LogicalFilter(condition=[EXISTS({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[AND(=($0, $cor0.o_custkey), =($1, $cor1.n_nationkey))])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "})], variablesSet=[[$cor1]])\n" + + " LogicalTableScan(table=[[nation]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + + assertEquals( + "SELECT \"o_orderkey\" AS \"o_orderkey0\", (((SELECT \"n_name\" AS \"n_name0\"\n" + + "FROM \"nation\"\n" + + "WHERE EXISTS (SELECT \"c_nationkey\"\n" + + "FROM \"customer\"\n" + + "WHERE \"c_custkey\" = \"orders\".\"o_custkey\" AND \"c_nationkey\" = \"nation\".\"n_nationkey\")))) AS \"$f3\"\n" + + "FROM \"orders\"", + SubstraitSqlDialect.toSql(calciteRel).getSql()); + } +}