Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.substrait.expression.Expression.IfThen;
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableExpression.Switch;
Expand Down Expand Up @@ -640,7 +641,7 @@ public Expression.ScalarFunctionInvocation or(Expression... args) {
}

public Expression.ScalarFunctionInvocation scalarFn(
String namespace, String key, Type outputType, Expression... args) {
String namespace, String key, Type outputType, FunctionArg... args) {
var declaration =
extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return Expression.ScalarFunctionInvocation.builder()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package io.substrait.isthmus.expression;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import io.substrait.expression.EnumArg;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.extension.SimpleExtension.Argument;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import org.apache.calcite.avatica.util.TimeUnitRange;
Expand All @@ -25,16 +28,34 @@
*/
public class EnumConverter {

private static final BiMap<Class<? extends Enum>, ArgAnchor> calciteEnumMap = HashBiMap.create();
private static final Map<ArgAnchor, Class<? extends Enum<?>>> calciteEnumMap = new HashMap<>();
Comment thread
nielspardon marked this conversation as resolved.

static {
// deprecated {@link io.substrait.type.Type.Timestamp}
calciteEnumMap.put(
TimeUnitRange.class,
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", 0));
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", 0),
TimeUnitRange.class);
// deprecated {@link io.substrait.type.Type.TimestampTZ}
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_tstz_str", 0),
TimeUnitRange.class);

calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_pts", 0),
TimeUnitRange.class);
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ptstz_str", 0),
TimeUnitRange.class);
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_date", 0),
TimeUnitRange.class);
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_time", 0),
TimeUnitRange.class);
}

private static Optional<Enum> constructValue(
Class<? extends Enum> cls, Supplier<Optional<String>> option) {
private static Optional<Enum<?>> constructValue(
Class<? extends Enum<?>> cls, Supplier<Optional<String>> option) {
if (cls.isAssignableFrom(TimeUnitRange.class)) {
return option.get().map(TimeUnitRange::valueOf);
} else {
Expand All @@ -44,8 +65,9 @@ private static Optional<Enum> constructValue(

static Optional<RexLiteral> toRex(
RexBuilder rexBuilder, SimpleExtension.Function fnDef, int argIdx, EnumArg e) {
var aAnch = argAnchor(fnDef, argIdx);
var v = Optional.ofNullable(calciteEnumMap.inverse().getOrDefault(aAnch, null));
ArgAnchor aAnch = argAnchor(fnDef, argIdx);
Optional<Class<? extends Enum<?>>> v =
Optional.ofNullable(calciteEnumMap.getOrDefault(aAnch, null));

Supplier<Optional<String>> sOptionVal =
() -> {
Expand All @@ -66,11 +88,11 @@ private static Optional<SimpleExtension.EnumArgument> findEnumArg(
return Optional.empty();
} else {

var args = function.args();
List<Argument> args = function.args();
if (args.size() <= enumAnchor.argIdx) {
return Optional.empty();
}
var arg = args.get(enumAnchor.argIdx);
Argument arg = args.get(enumAnchor.argIdx);
if (arg instanceof SimpleExtension.EnumArgument ea) {
return Optional.of(ea);
} else {
Expand All @@ -79,17 +101,15 @@ private static Optional<SimpleExtension.EnumArgument> findEnumArg(
}
}

static Optional<EnumArg> fromRex(SimpleExtension.Function function, RexLiteral literal) {
static Optional<EnumArg> fromRex(
SimpleExtension.Function function, RexLiteral literal, int argIdx) {
return switch (literal.getType().getSqlTypeName()) {
case SYMBOL -> {
Object v = literal.getValue();
if (!literal.isNull() && (v instanceof Enum)) {
Enum value = (Enum) v;
Optional<ArgAnchor> enumAnchor =
Optional.ofNullable(calciteEnumMap.getOrDefault(value.getClass(), null));
yield enumAnchor
.flatMap(en -> findEnumArg(function, en))
.map(ea -> EnumArg.of(ea, value.name()));
Enum<?> value = (Enum<?>) v;
ArgAnchor enumAnchor = argAnchor(function, argIdx);
yield findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name()));
} else {
yield Optional.empty();
}
Expand All @@ -98,8 +118,8 @@ static Optional<EnumArg> fromRex(SimpleExtension.Function function, RexLiteral l
};
}

static boolean canConvert(Enum value) {
return value != null && calciteEnumMap.containsKey(value.getClass());
static boolean canConvert(Enum<?> value) {
return value != null && calciteEnumMap.containsValue(value.getClass());
}

static boolean isEnumValue(RexNode value) {
Expand All @@ -116,6 +136,23 @@ public ArgAnchor(final SimpleExtension.FunctionAnchor fn, final int argIdx) {
this.fn = fn;
this.argIdx = argIdx;
}

@Override
public int hashCode() {
return Objects.hash(fn, argIdx);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (!(obj instanceof ArgAnchor)) {
return false;
}
ArgAnchor other = (ArgAnchor) obj;
return Objects.equals(fn, other.fn) && argIdx == other.argIdx;
}
}

private static ArgAnchor argAnchor(String fnNS, String fnSig, int argIdx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.Expression.PrecisionTimestampLiteral;
import io.substrait.expression.Expression.PrecisionTimestampTZLiteral;
import io.substrait.expression.Expression.ScalarSubquery;
import io.substrait.expression.Expression.SetPredicate;
import io.substrait.expression.Expression.SingleOrList;
import io.substrait.expression.Expression.Switch;
import io.substrait.expression.Expression.TimestampTZLiteral;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
Expand Down Expand Up @@ -209,20 +212,65 @@ public RexNode visit(Expression.DateLiteral expr) throws RuntimeException {

@Override
public RexNode visit(Expression.TimestampLiteral expr) throws RuntimeException {
// Expression.TimestampLiteral is microseconds
// Construct a TimeStampString :
// 1. Truncate microseconds to seconds
// 2. Get the fraction seconds in precision of nanoseconds.
// 3. Construct TimeStampString : seconds + fraction_seconds part.
long microSec = expr.value();
long seconds = TimeUnit.MICROSECONDS.toSeconds(microSec);
int fracSecondsInNano =
(int) (TimeUnit.MICROSECONDS.toNanos(microSec) - TimeUnit.SECONDS.toNanos(seconds));
return rexBuilder.makeLiteral(
getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType()));
}

TimestampString tsString =
TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
.withNanos(fracSecondsInNano);
return rexBuilder.makeLiteral(tsString, typeConverter.toCalcite(typeFactory, expr.getType()));
@Override
public RexNode visit(TimestampTZLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(
getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType()));
}

@Override
public RexNode visit(PrecisionTimestampLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(
getTimestampString(expr.value(), expr.precision()),
typeConverter.toCalcite(typeFactory, expr.getType()));
}

@Override
public RexNode visit(PrecisionTimestampTZLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(
getTimestampString(expr.value(), expr.precision()),
typeConverter.toCalcite(typeFactory, expr.getType()));
}

private TimestampString getTimestampString(long microSec) {
return getTimestampString(microSec, 6);
}

private TimestampString getTimestampString(long value, int precision) {
Comment thread
nielspardon marked this conversation as resolved.
switch (precision) {
case 0:
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(value));
case 3:
{
long seconds = TimeUnit.MILLISECONDS.toSeconds(value);
int fracSecondsInNano =
(int) (TimeUnit.MILLISECONDS.toNanos(value) - TimeUnit.SECONDS.toNanos(seconds));
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
.withNanos(fracSecondsInNano);
}
case 6:
{
long seconds = TimeUnit.MICROSECONDS.toSeconds(value);
int fracSecondsInNano =
(int) (TimeUnit.MICROSECONDS.toNanos(value) - TimeUnit.SECONDS.toNanos(seconds));
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
.withNanos(fracSecondsInNano);
}
case 9:
{
long seconds = TimeUnit.NANOSECONDS.toSeconds(value);
int fracSecondsInNano = (int) (value - TimeUnit.SECONDS.toNanos(seconds));
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
.withNanos(fracSecondsInNano);
}
default:
throw new UnsupportedOperationException(
String.format("Cannot handle PrecisionTimestamp with precision %d.", precision));
}
Comment thread
nielspardon marked this conversation as resolved.
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
Expand Down Expand Up @@ -124,7 +125,7 @@ public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type ou
operator ->
resolver.containsKey(operator)
&& resolver.get(operator).types().contains(outputTypeStr))
.collect(java.util.stream.Collectors.toList());
.collect(Collectors.toList());
// only one SqlOperator is possible
if (resolvedOperators.size() == 1) {
return Optional.of(resolvedOperators.get(0));
Expand Down Expand Up @@ -331,7 +332,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
}
return isOption ? List.of("req", "opt") : List.of(opType);
})
.collect(java.util.stream.Collectors.toList());
.collect(Collectors.toList());

return Utils.crossProduct(argTypeLists)
.map(typList -> typList.stream().collect(Collectors.joining("_")));
Expand All @@ -346,42 +347,43 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
* Once a FunctionVariant is resolved we can map the String Literal
* to a EnumArg.
*/
var operands =
call.getOperands().map(topLevelConverter).collect(java.util.stream.Collectors.toList());
var opTypes =
operands.stream().map(Expression::getType).collect(java.util.stream.Collectors.toList());
List<RexNode> operandsList = call.getOperands().collect(Collectors.toList());
List<Expression> operands =
call.getOperands().map(topLevelConverter).collect(Collectors.toList());
List<Type> opTypes = operands.stream().map(Expression::getType).collect(Collectors.toList());

var outputType = typeConverter.toSubstrait(call.getType());
Type outputType = typeConverter.toSubstrait(call.getType());

// try to do a direct match
var typeStrings =
List<String> typeStrings =
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
var possibleKeys =
matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);
Stream<String> possibleKeys =
matchKeys(call.getOperands().collect(Collectors.toList()), typeStrings);

var directMatchKey =
Optional<String> directMatchKey =
possibleKeys
.map(argList -> name + ":" + argList)
.filter(k -> directMap.containsKey(k))
.findFirst();

if (directMatchKey.isPresent()) {
var variant = directMap.get(directMatchKey.get());
F variant = directMap.get(directMatchKey.get());
variant.validateOutputType(operands, outputType);

List<FunctionArg> funcArgs =
Streams.zip(
call.getOperands(),
operands.stream(),
(r, o) -> {
IntStream.range(0, operandsList.size())
.mapToObj(
i -> {
RexNode r = operandsList.get(i);
Expression o = operands.get(i);
if (EnumConverter.isEnumValue(r)) {
return EnumConverter.fromRex(variant, (RexLiteral) r).orElseGet(() -> null);
return EnumConverter.fromRex(variant, (RexLiteral) r, i)
.orElseGet(() -> null);
} else {
return o;
}
})
.collect(java.util.stream.Collectors.toList());
var allArgsMapped = funcArgs.stream().filter(e -> e == null).findFirst().isEmpty();
.collect(Collectors.toList());
boolean allArgsMapped = funcArgs.stream().filter(e -> e == null).findFirst().isEmpty();
if (allArgsMapped) {
return Optional.of(generateBinding(call, variant, funcArgs, outputType));
} else {
Expand Down
Loading
Loading