Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.Expression.IfClause;
import io.substrait.expression.Expression.IfThen;
import io.substrait.expression.Expression.PredicateOp;
import io.substrait.expression.Expression.SingleOrList;
import io.substrait.expression.Expression.Switch;
import io.substrait.expression.Expression.SwitchClause;
Expand Down Expand Up @@ -644,6 +645,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig
DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
}

public Expression.ScalarFunctionInvocation and(Expression... args) {
// If any arg is nullable, the output of and is potentially nullable
// For example: false and null = null
boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "and:bool", outputType, args);
}

public Expression.ScalarFunctionInvocation or(Expression... args) {
// If any arg is nullable, the output of or is potentially nullable
// For example: false or null = null
Expand Down Expand Up @@ -706,4 +715,15 @@ public Plan plan(Plan.Root root) {
public Rel.Remap remap(Integer... fields) {
return Rel.Remap.of(Arrays.asList(fields));
}

public Expression scalarSubquery(Rel input, Type type) {
return Expression.ScalarSubquery.builder().input(input).type(type).build();
}

public Expression exists(Rel rel) {
return Expression.SetPredicate.builder()
.tuples(rel)
.predicateOp(PredicateOp.PREDICATE_OP_EXISTS)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}

public boolean isSimpleRootReference() {
return segments().size() == 1 && !inputExpression().isPresent();
return segments().size() == 1
&& !inputExpression().isPresent()
&& !outerReferenceStepsOut().isPresent();
}

public boolean isOuterReference() {
return outerReferenceStepsOut().orElse(0) > 0;
}

public FieldReference dereferenceStruct(int index) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.TreeRangeMap;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.SortDirection;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -34,9 +37,11 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Stack;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
Expand All @@ -49,6 +54,7 @@
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.logical.LogicalTableModify;
Expand Down Expand Up @@ -156,8 +162,11 @@ public static RelNode convert(
@Override
public RelNode visit(Filter filter, Context context) throws RuntimeException {
RelNode input = filter.getInput().accept(this, context);
context.pushOuterRowType(input.getRowType());
RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context);
RelNode node = relBuilder.push(input).filter(filterCondition).build();
RelNode node =
relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build();
context.popOuterRowType();
return applyRemap(node, filter.getRemap());
}

Expand All @@ -183,6 +192,8 @@ public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeExcepti
@Override
public RelNode visit(Project project, Context context) throws RuntimeException {
RelNode child = project.getInput().accept(this, context);
context.pushOuterRowType(child.getRowType());

Stream<RexNode> directOutputs =
IntStream.range(0, child.getRowType().getFieldCount())
.mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex));
Expand All @@ -193,7 +204,12 @@ public RelNode visit(Project project, Context context) throws RuntimeException {
List<RexNode> rexExprs =
Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList());

RelNode node = relBuilder.push(child).project(rexExprs).build();
RelNode node =
relBuilder
.push(child)
.project(rexExprs, List.of(), false, context.popCorrelationIds())
.build();
context.popOuterRowType();
return applyRemap(node, project.getRemap());
}

Expand All @@ -211,12 +227,19 @@ public RelNode visit(Cross cross, Context context) throws RuntimeException {
public RelNode visit(Join join, Context context) throws RuntimeException {
RelNode left = join.getLeft().accept(this, context);
RelNode right = join.getRight().accept(this, context);
context.pushOuterRowType(left.getRowType(), right.getRowType());
RexNode condition =
join.getCondition()
.map(c -> c.accept(expressionRexConverter, context))
.orElse(relBuilder.literal(true));
JoinRelType joinType = asJoinRelType(join);
RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build();
RelNode node =
relBuilder
.push(left)
.push(right)
.join(joinType, condition, context.popCorrelationIds())
.build();
context.popOuterRowType();
return applyRemap(node, join.getRemap());
}

Expand Down Expand Up @@ -626,9 +649,101 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) {
return relBuilder.push(relNode).project(rexList).build();
}

/** A shared context for the Substrait to RelNode conversion. */
public static class Context implements VisitationContext {
protected final Stack<RangeMap<Integer, RelDataType>> outerRowTypes = new Stack<>();

protected final Stack<java.util.Set<CorrelationId>> correlationIds = new Stack<>();

private int subqueryDepth;

/**
* Creates a new {@link Context} instance.
*
* @return the new {@link Context} instance
*/
public static Context newContext() {
return new Context();
}

/**
* Adds the outer row types to the top of the stack of outer row types.
*
* <p>Row types are stored as a {@link RangeMap} with field indices as keys and the {@link
* RelDataType} row type containing the field at the field index by continuously numbering the
* field indices from 0 across all provided row types in the order the row types are passed as
* arguments.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RangeMap is super nice for this ✨

*
* @param inputs the row types to add
*/
public void pushOuterRowType(final RelDataType... inputs) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pushOuterRowType is a good name for this. We should update the popParentRelNodes to match.

final RangeMap<Integer, RelDataType> fieldRangeMap = TreeRangeMap.create();
int begin = 0;
for (final RelDataType parent : inputs) {
final int end = begin + parent.getFieldCount();
final Range<Integer> range = Range.closedOpen(begin, end);
fieldRangeMap.put(range, parent);
begin = end;
}

outerRowTypes.push(fieldRangeMap);
this.correlationIds.push(new HashSet<>());
}

public void popOuterRowType() {
outerRowTypes.pop();
}

/**
* Returns the outer row type {@link RangeMap} walking up the given steps from the current
* subquery depth.
*
* @param stepsOut number of steps to walk up from the current subquery depth
* @return {@link RangeMap} with field indices as keys and the {@link RelDataType} row type
* containing the field at the field index
*/
public RangeMap<Integer, RelDataType> getOuterRowTypeRangeMap(final Integer stepsOut) {
return this.outerRowTypes.get(subqueryDepth - stepsOut);
}

/**
* Removes the correlation ids at the top of the stack.
*
* @return the correlation ids removed from the top of the stack
*/
public java.util.Set<CorrelationId> popCorrelationIds() {
return correlationIds.pop();
}

/**
* Adds a {@link CorrelationId} to the subquery depth walking up the given steps from the
* current subquery depth.
*
* @param stepsOut number of steps to walk up from the current subquery depth
* @param correlationId the {@link CorrelationId} to add
*/
public void addCorrelationId(final int stepsOut, final CorrelationId correlationId) {
final int index = subqueryDepth - stepsOut;
this.correlationIds.get(index).add(correlationId);
}

/** Increments the current subquery depth. */
public void incrementSubqueryDepth() {
this.subqueryDepth++;
}

/** Decrements the current subquery depth. */
public void decrementSubqueryDepth() {
this.subqueryDepth--;
}
}

/**
* Returns the {@link RelBuilder} of this converter.
*
* @return the {@link RelBuilder}
*/
public RelBuilder getRelBuilder() {
return relBuilder;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.substrait.isthmus.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import io.substrait.expression.AbstractExpressionVisitor;
import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
Expand Down Expand Up @@ -33,6 +35,7 @@
import java.util.stream.Stream;
import org.apache.calcite.avatica.util.ByteString;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
Expand Down Expand Up @@ -513,7 +516,9 @@ private boolean isDistinct(Expression.WindowFunctionInvocation expr) {
public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException {
List<RexNode> needles =
expr.needles().stream().map(e -> e.accept(this, context)).collect(Collectors.toList());
context.incrementSubqueryDepth();
RelNode rel = expr.haystack().accept(relNodeConverter, context);
Comment thread
nielspardon marked this conversation as resolved.
context.decrementSubqueryDepth();
return RexSubQuery.in(rel, ImmutableList.copyOf(needles));
}

Expand Down Expand Up @@ -589,13 +594,37 @@ public RexNode visit(Expression.Cast expr, Context context) throws RuntimeExcept
@Override
public RexNode visit(FieldReference expr, Context context) throws RuntimeException {
if (expr.isSimpleRootReference()) {
ReferenceSegment segment = expr.segments().get(0);
final ReferenceSegment segment = expr.segments().get(0);

RexInputRef rexInputRef;
final RexInputRef rexInputRef;
if (segment instanceof FieldReference.StructField) {
FieldReference.StructField f = (FieldReference.StructField) segment;
final FieldReference.StructField field = (FieldReference.StructField) segment;
rexInputRef =
new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType()));
new RexInputRef(field.offset(), typeConverter.toCalcite(typeFactory, expr.getType()));
} else {
throw new IllegalArgumentException("Unhandled type: " + segment);
}

return rexInputRef;
} else if (expr.isOuterReference()) {
final ReferenceSegment segment = expr.segments().get(0);

final RexNode rexInputRef;
if (segment instanceof FieldReference.StructField) {
final FieldReference.StructField field = (FieldReference.StructField) segment;

final RangeMap<Integer, RelDataType> fieldRangeMap =
context.getOuterRowTypeRangeMap(expr.outerReferenceStepsOut().get());
final Range<Integer> range = fieldRangeMap.getEntry(field.offset()).getKey();
final int fieldOffset = field.offset() - range.lowerEndpoint();

final CorrelationId correlationId =
relNodeConverter.getRelBuilder().getCluster().createCorrel();
context.addCorrelationId(expr.outerReferenceStepsOut().get(), correlationId);
rexInputRef =
rexBuilder.makeFieldAccess(
rexBuilder.makeCorrel(fieldRangeMap.get(field.offset()), correlationId),
fieldOffset);
} else {
throw new IllegalArgumentException("Unhandled type: " + segment);
}
Expand Down Expand Up @@ -646,13 +675,17 @@ public RexNode visitEnumArg(

@Override
public RexNode visit(ScalarSubquery expr, Context context) throws RuntimeException {
context.incrementSubqueryDepth();
RelNode inputRelnode = expr.input().accept(relNodeConverter, context);
context.decrementSubqueryDepth();
return RexSubQuery.scalar(inputRelnode);
}

@Override
public RexNode visit(SetPredicate expr, Context context) throws RuntimeException {
context.incrementSubqueryDepth();
RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context);
Comment thread
vbarua marked this conversation as resolved.
context.decrementSubqueryDepth();
switch (expr.predicateOp()) {
case PREDICATE_OP_EXISTS:
return RexSubQuery.exists(inputRelnode);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.proto.Plan;
import java.io.IOException;
Expand All @@ -13,7 +14,7 @@
/** TPC-DS test to convert SQL to Substrait and then convert those plans back to SQL. */
public class TpcdsQueryTest extends PlanTestBase {
private static final Set<Integer> toSubstraitExclusions = Set.of(9, 27, 36, 70, 86);
private static final Set<Integer> fromSubstraitExclusions = Set.of(6, 8, 67);
private static final Set<Integer> fromSubstraitExclusions = Set.of(1, 30, 67, 81);

static IntStream testCases() {
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));
Expand All @@ -32,6 +33,8 @@ public void testQuery(int query) throws IOException {

if (!fromSubstraitExclusions.contains(query)) {
assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL");
} else {
assertThrows(Throwable.class, () -> toSql(plan), "Substrait to SQL");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@

import io.substrait.proto.Plan;
import java.io.IOException;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

/** TPC-H test to convert SQL to Substrait and then convert those plans back to SQL. */
public class TpchQueryTest extends PlanTestBase {
private static final Set<Integer> fromSubstraitExclusions = Set.of(17);

static IntStream testCases() {
return IntStream.rangeClosed(1, 22);
}
Expand All @@ -29,9 +26,7 @@ public void testQuery(int query) throws IOException {

Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait");

if (!fromSubstraitExclusions.contains(query)) {
assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL");
}
assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL");
}

private Plan toSubstraitPlan(String sql) throws SqlParseException {
Expand Down
Loading
Loading