Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ configurations[JavaPlugin.TEST_IMPLEMENTATION_CONFIGURATION_NAME].extendsFrom(sh

dependencies {
testImplementation(platform(libs.junit.bom))
testImplementation(libs.protobuf.java.util)

testImplementation(libs.junit.jupiter)
testRuntimeOnly(libs.junit.platform.launcher)

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ public Expression.WindowFunctionInvocation windowFn(
// Types

public Type.UserDefined userDefinedType(String namespace, String typeName) {
return Type.UserDefined.builder().uri(namespace).name(typeName).nullable(false).build();
return Type.UserDefined.builder().urn(namespace).name(typeName).nullable(false).build();
}

// Misc
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,13 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
abstract class UserDefinedLiteral implements Literal {
public abstract ByteString value();

public abstract String uri();
public abstract String urn();

public abstract String name();

@Override
public Type getType() {
return Type.withNullability(nullable()).userDefined(uri(), name());
return Type.withNullability(nullable()).userDefined(urn(), name());
}

public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ public static Expression.StructLiteral struct(
}

public static Expression.UserDefinedLiteral userDefinedLiteral(
boolean nullable, String uri, String name, Any value) {
boolean nullable, String urn, String name, Any value) {
return Expression.UserDefinedLiteral.builder()
.nullable(nullable)
.uri(uri)
.urn(urn)
.name(name)
.value(value.toByteString())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ public Expression visit(
public Expression visit(
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name()));
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
SimpleExtension.Type type =
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
return ExpressionCreator.userDefinedLiteral(
literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue());
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
}
default:
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ public ProtoExtendedExpressionConverter() {
}

public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) {
if (extensionCollection == null) {
throw new IllegalArgumentException("ExtensionCollection is required");
}
this.extensionCollection = extensionCollection;
}

public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) {
// fill in simple extension information through a discovery in the current proto-extended
// expression
ExtensionLookup functionLookup =
ImmutableExtensionLookup.builder().from(extendedExpression).build();
ImmutableExtensionLookup.builder(extensionCollection).from(extendedExpression).build();

NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();

Expand Down
65 changes: 65 additions & 0 deletions core/src/main/java/io/substrait/extension/BidiMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package io.substrait.extension;
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already existed in ExtensionCollector.java as a private class. Instead, I have moved it into a separate file so that we can also leverage it for a bijection between uris and urns for migration purposes.


import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/** We don't depend on guava... */
public class BidiMap<T1, T2> {
Comment thread
benbellick marked this conversation as resolved.
private final Map<T1, T2> forwardMap;
private final Map<T2, T1> reverseMap;

BidiMap(Map<T1, T2> forwardMap) {
this.forwardMap = forwardMap;
this.reverseMap = new HashMap<>();
for (Map.Entry<T1, T2> entry : forwardMap.entrySet()) {
reverseMap.put(entry.getValue(), entry.getKey());
}
}

BidiMap() {
this.forwardMap = new HashMap<>();
this.reverseMap = new HashMap<>();
}

T2 get(T1 t1) {
return forwardMap.get(t1);
}

T1 reverseGet(T2 t2) {
return reverseMap.get(t2);
}

/**
* Associates the specified values in both directions. Throws if either value is already mapped to
* a different value.
*/
void put(T1 t1, T2 t2) {
T2 existingForward = forwardMap.get(t1);
T1 existingReverse = reverseMap.get(t2);

if (existingForward != null && !existingForward.equals(t2)) {
throw new IllegalArgumentException("Key already exists in map with different value");
}
if (existingReverse != null && !existingReverse.equals(t1)) {
throw new IllegalArgumentException("Key already exists in map with different value");
}

forwardMap.put(t1, t2);
reverseMap.put(t2, t1);
}

void merge(BidiMap<T1, T2> other) {
for (Map.Entry<T1, T2> entry : other.forwardEntrySet()) {
put(entry.getKey(), entry.getValue());
}
}

Set<Map.Entry<T1, T2>> forwardEntrySet() {
return forwardMap.entrySet();
}

Set<Map.Entry<T2, T1>> reverseEntrySet() {
return reverseMap.entrySet();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@
import java.util.stream.Collectors;

public class DefaultExtensionCatalog {
public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml";
public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
public static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "/functions_arithmetic_decimal.yaml";
public static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
public static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";
public static final String FUNCTIONS_DATETIME = "/functions_datetime.yaml";
public static final String FUNCTIONS_GEOMETRY = "/functions_geometry.yaml";
public static final String FUNCTIONS_LOGARITHMIC = "/functions_logarithmic.yaml";
public static final String FUNCTIONS_ROUNDING = "/functions_rounding.yaml";
public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml";
public static final String FUNCTIONS_SET = "/functions_set.yaml";
public static final String FUNCTIONS_STRING = "/functions_string.yaml";
public static final String FUNCTIONS_AGGREGATE_APPROX =
"extension:io.substrait:functions_aggregate_approx";
public static final String FUNCTIONS_AGGREGATE_GENERIC =
"extension:io.substrait:functions_aggregate_generic";
public static final String FUNCTIONS_ARITHMETIC = "extension:io.substrait:functions_arithmetic";
public static final String FUNCTIONS_ARITHMETIC_DECIMAL =
"extension:io.substrait:functions_arithmetic_decimal";
public static final String FUNCTIONS_BOOLEAN = "extension:io.substrait:functions_boolean";
public static final String FUNCTIONS_COMPARISON = "extension:io.substrait:functions_comparison";
public static final String FUNCTIONS_DATETIME = "extension:io.substrait:functions_datetime";
public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry";
public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic";
public static final String FUNCTIONS_ROUNDING = "extension:io.substrait:functions_rounding";
public static final String FUNCTIONS_ROUNDING_DECIMAL =
"extension:io.substrait:functions_rounding_decimal";
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";

public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
loadDefaultCollection();
Expand Down
Loading