diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 09706f115..97b9dfbf6 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -108,6 +108,8 @@ configurations[JavaPlugin.TEST_IMPLEMENTATION_CONFIGURATION_NAME].extendsFrom(sh dependencies { testImplementation(platform(libs.junit.bom)) + testImplementation(libs.protobuf.java.util) + testImplementation(libs.junit.jupiter) testRuntimeOnly(libs.junit.platform.launcher) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 4e8e428f7..ca6cd4ce3 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -699,7 +699,7 @@ public Expression.WindowFunctionInvocation windowFn( // Types public Type.UserDefined userDefinedType(String namespace, String typeName) { - return Type.UserDefined.builder().uri(namespace).name(typeName).nullable(false).build(); + return Type.UserDefined.builder().urn(namespace).name(typeName).nullable(false).build(); } // Misc diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index e5c45e19e..42c3c5118 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -666,13 +666,13 @@ public R accept( abstract class UserDefinedLiteral implements Literal { public abstract ByteString value(); - public abstract String uri(); + public abstract String urn(); public abstract String name(); @Override public Type getType() { - return Type.withNullability(nullable()).userDefined(uri(), name()); + return Type.withNullability(nullable()).userDefined(urn(), name()); } public static ImmutableExpression.UserDefinedLiteral.Builder builder() { diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java 
b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 4a18026ab..adf157d7b 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -287,10 +287,10 @@ public static Expression.StructLiteral struct( } public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String uri, String name, Any value) { + boolean nullable, String urn, String name, Any value) { return Expression.UserDefinedLiteral.builder() .nullable(nullable) - .uri(uri) + .urn(urn) .name(name) .value(value.toByteString()) .build(); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 093e3fff3..caf145dfc 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -361,7 +361,7 @@ public Expression visit( public Expression visit( io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { int typeReference = - extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return lit( bldr -> { try { diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 859ddfca5..8f95cdf07 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -495,7 +495,7 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); return 
ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); + literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index c658fcbce..175b2d705 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -28,6 +28,9 @@ public ProtoExtendedExpressionConverter() { } public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } this.extensionCollection = extensionCollection; } @@ -35,7 +38,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp // fill in simple extension information through a discovery in the current proto-extended // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder().from(extendedExpression).build(); + ImmutableExtensionLookup.builder(extensionCollection).from(extendedExpression).build(); NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); diff --git a/core/src/main/java/io/substrait/extension/BidiMap.java b/core/src/main/java/io/substrait/extension/BidiMap.java new file mode 100644 index 000000000..f0eeb30a7 --- /dev/null +++ b/core/src/main/java/io/substrait/extension/BidiMap.java @@ -0,0 +1,65 @@ +package io.substrait.extension; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** We don't depend on guava... 
*/ +public class BidiMap { + private final Map forwardMap; + private final Map reverseMap; + + BidiMap(Map forwardMap) { + this.forwardMap = forwardMap; + this.reverseMap = new HashMap<>(); + for (Map.Entry entry : forwardMap.entrySet()) { + reverseMap.put(entry.getValue(), entry.getKey()); + } + } + + BidiMap() { + this.forwardMap = new HashMap<>(); + this.reverseMap = new HashMap<>(); + } + + T2 get(T1 t1) { + return forwardMap.get(t1); + } + + T1 reverseGet(T2 t2) { + return reverseMap.get(t2); + } + + /** + * Associates the specified values in both directions. Throws if either value is already mapped to + * a different value. + */ + void put(T1 t1, T2 t2) { + T2 existingForward = forwardMap.get(t1); + T1 existingReverse = reverseMap.get(t2); + + if (existingForward != null && !existingForward.equals(t2)) { + throw new IllegalArgumentException("Key already exists in map with different value"); + } + if (existingReverse != null && !existingReverse.equals(t1)) { + throw new IllegalArgumentException("Key already exists in map with different value"); + } + + forwardMap.put(t1, t2); + reverseMap.put(t2, t1); + } + + void merge(BidiMap other) { + for (Map.Entry entry : other.forwardEntrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + Set> forwardEntrySet() { + return forwardMap.entrySet(); + } + + Set> reverseEntrySet() { + return reverseMap.entrySet(); + } +} diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 39d28298f..89aad954e 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -5,19 +5,23 @@ import java.util.stream.Collectors; public class DefaultExtensionCatalog { - public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml"; - public static final String FUNCTIONS_AGGREGATE_GENERIC = 
"/functions_aggregate_generic.yaml"; - public static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml"; - public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "/functions_arithmetic_decimal.yaml"; - public static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml"; - public static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml"; - public static final String FUNCTIONS_DATETIME = "/functions_datetime.yaml"; - public static final String FUNCTIONS_GEOMETRY = "/functions_geometry.yaml"; - public static final String FUNCTIONS_LOGARITHMIC = "/functions_logarithmic.yaml"; - public static final String FUNCTIONS_ROUNDING = "/functions_rounding.yaml"; - public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml"; - public static final String FUNCTIONS_SET = "/functions_set.yaml"; - public static final String FUNCTIONS_STRING = "/functions_string.yaml"; + public static final String FUNCTIONS_AGGREGATE_APPROX = + "extension:io.substrait:functions_aggregate_approx"; + public static final String FUNCTIONS_AGGREGATE_GENERIC = + "extension:io.substrait:functions_aggregate_generic"; + public static final String FUNCTIONS_ARITHMETIC = "extension:io.substrait:functions_arithmetic"; + public static final String FUNCTIONS_ARITHMETIC_DECIMAL = + "extension:io.substrait:functions_arithmetic_decimal"; + public static final String FUNCTIONS_BOOLEAN = "extension:io.substrait:functions_boolean"; + public static final String FUNCTIONS_COMPARISON = "extension:io.substrait:functions_comparison"; + public static final String FUNCTIONS_DATETIME = "extension:io.substrait:functions_datetime"; + public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry"; + public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic"; + public static final String FUNCTIONS_ROUNDING = "extension:io.substrait:functions_rounding"; + public static final String 
FUNCTIONS_ROUNDING_DECIMAL = + "extension:io.substrait:functions_rounding_decimal"; + public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index d408c600b..7ad07a6b1 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -4,6 +4,7 @@ import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; @@ -19,14 +20,27 @@ public class ExtensionCollector extends AbstractExtensionLookup { private final BidiMap funcMap; private final BidiMap typeMap; + private final SimpleExtension.ExtensionCollection extensionCollection; // start at 0 to make sure functionAnchors start with 1 according to spec private int counter = 0; + private String getUriFromUrn(String urn) { + return extensionCollection.getUriFromUrn(urn); + } + public ExtensionCollector() { + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + public ExtensionCollector(SimpleExtension.ExtensionCollection extensionCollection) { super(new HashMap<>(), new HashMap<>()); + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } funcMap = new BidiMap<>(functionAnchorMap); typeMap = new BidiMap<>(typeAnchorMap); + this.extensionCollection = extensionCollection; } public int getFunctionReference(SimpleExtension.Function declaration) { @@ -52,6 +66,7 @@ public int getTypeReference(SimpleExtension.TypeAnchor 
typeAnchor) { public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); + builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensionUris(simpleExtensions.uris.values()); builder.addAllExtensions(simpleExtensions.extensionList); } @@ -59,89 +74,118 @@ public void addExtensionsToPlan(Plan.Builder builder) { public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); + builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensionUris(simpleExtensions.uris.values()); builder.addAllExtensions(simpleExtensions.extensionList); } private SimpleExtensions getExtensions() { + AtomicInteger urnPos = new AtomicInteger(1); AtomicInteger uriPos = new AtomicInteger(1); + HashMap urns = new HashMap<>(); HashMap uris = new HashMap<>(); ArrayList extensionList = new ArrayList<>(); - for (Map.Entry e : funcMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), + for (Map.Entry e : funcMap.forwardEntrySet()) { + String urn = e.getValue().urn(); + String uri = getUriFromUrn(urn); + + // Create URN entry + SimpleExtensionURN urnObj = + urns.computeIfAbsent( + urn, k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(urnPos.getAndIncrement()) + .setUrn(k) .build()); + + // Create URI entry if mapping exists + SimpleExtensionURI uriObj = null; + if (uri != null) { + uriObj = + uris.computeIfAbsent( + uri, + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + } + + // Create function declaration with both URN and URI references + SimpleExtensionDeclaration.ExtensionFunction.Builder funcBuilder = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(e.getKey()) + 
.setName(e.getValue().key()) + .setExtensionUrnReference(urnObj.getExtensionUrnAnchor()); + + if (uriObj != null) { + funcBuilder.setExtensionUriReference(uriObj.getExtensionUriAnchor()); + } + SimpleExtensionDeclaration decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(funcBuilder).build(); extensionList.add(decl); } - for (Map.Entry e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), + + for (Map.Entry e : typeMap.forwardEntrySet()) { + String urn = e.getValue().urn(); + String uri = getUriFromUrn(urn); + + // Create URN entry + SimpleExtensionURN urnObj = + urns.computeIfAbsent( + urn, k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(urnPos.getAndIncrement()) + .setUrn(k) .build()); + + // Create URI entry if mapping exists + SimpleExtensionURI uriObj = null; + if (uri != null) { + uriObj = + uris.computeIfAbsent( + uri, + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + } + + // Create type declaration with both URN and URI references + SimpleExtensionDeclaration.ExtensionType.Builder typeBuilder = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUrnReference(urnObj.getExtensionUrnAnchor()); + + if (uriObj != null) { + typeBuilder.setExtensionUriReference(uriObj.getExtensionUriAnchor()); + } + SimpleExtensionDeclaration decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionType( - SimpleExtensionDeclaration.ExtensionType.newBuilder() - 
.setTypeAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); + SimpleExtensionDeclaration.newBuilder().setExtensionType(typeBuilder).build(); extensionList.add(decl); } - return new SimpleExtensions(uris, extensionList); + return new SimpleExtensions(urns, uris, extensionList); } private static final class SimpleExtensions { + final HashMap urns; final HashMap uris; final ArrayList extensionList; SimpleExtensions( + HashMap urns, HashMap uris, ArrayList extensionList) { + this.urns = urns; this.uris = uris; this.extensionList = extensionList; } } - - /** We don't depend on guava... */ - private static class BidiMap { - private final Map forwardMap; - private final Map reverseMap; - - public BidiMap(Map forwardMap) { - this.forwardMap = forwardMap; - this.reverseMap = new HashMap<>(); - } - - public T2 get(T1 t1) { - return forwardMap.get(t1); - } - - public T1 reverseGet(T2 t2) { - return reverseMap.get(t2); - } - - public void put(T1 t1, T2 t2) { - forwardMap.put(t1, t2); - reverseMap.put(t2, t1); - } - } } diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 6ac4fe922..9b546d9dc 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -3,7 +3,7 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; -import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -22,28 +22,208 @@ private ImmutableExtensionLookup( } public static Builder builder() { - return new Builder(); + return builder(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + public static Builder 
builder(SimpleExtension.ExtensionCollection extensionCollection) { + return new Builder(extensionCollection); } public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); + private final SimpleExtension.ExtensionCollection extensionCollection; + + public Builder(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } + this.extensionCollection = extensionCollection; + } + + /** + * Resolves URN from URI using the URI/URN mapping. + * + * @param uri The URI to resolve + * @return The corresponding URN, or null if no mapping exists + */ + private String resolveUrnFromUri(String uri) { + return extensionCollection.getUrnFromUri(uri); + } + + private SimpleExtension.FunctionAnchor resolveFunctionAnchor( + SimpleExtensionDeclaration.ExtensionFunction func, + Map urnMap, + Map uriMap) { + + // 1. Try non-zero URN reference + if (func.getExtensionUrnReference() != 0) { + String urnFromUrnRef = urnMap.get(func.getExtensionUrnReference()); + if (urnFromUrnRef == null) { + throw new IllegalStateException( + String.format( + "Function '%s' references URN anchor %d, but no URN is registered at that anchor", + func.getName(), func.getExtensionUrnReference())); + } + return SimpleExtension.FunctionAnchor.of(urnFromUrnRef, func.getName()); + } + + // 2. 
Try non-zero URI reference + if (func.getExtensionUriReference() != 0) { + String uriFromUriRef = uriMap.get(func.getExtensionUriReference()); + if (uriFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Function '%s' references URI anchor %d, but no URI is registered at that anchor", + func.getName(), func.getExtensionUriReference())); + } + String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + if (urnFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Function '%s' references URI anchor %d with URI '%s', but this URI could not be resolved to a URN. " + + "Ensure a URI <-> URN mapping is registered in the ExtensionCollection.", + func.getName(), func.getExtensionUriReference(), uriFromUriRef)); + } + return SimpleExtension.FunctionAnchor.of(urnFromUriRef, func.getName()); + } + + /* At this point, both URI and URN are known be 0. + * With protobufs, we cannot distinguish between 0 as an + * intentional value vs 0 as a default value. + * We perform some additional checks to below to handle this. + */ + + String urn = urnMap.get(func.getExtensionUrnReference()); + String uri = uriMap.get(func.getExtensionUriReference()); + + // 3. Try both 0 URI and 0 URN if both resolve + if (uri != null && urn != null) { + String resolvedUrn = resolveUrnFromUri(uri); + if (urn.equals(resolvedUrn)) { + return SimpleExtension.FunctionAnchor.of(urn, func.getName()); + } + throw new IllegalStateException( + String.format( + "Conflicting URI/URN mapping at reference 0: URI '%s' maps to URN '%s', but reference 0 also specifies URN '%s'. " + + "These must be consistent for proper resolution.", + uri, resolvedUrn, urn)); + } + + // 4. Try only 0 URN + if (urn != null) { + return SimpleExtension.FunctionAnchor.of(urn, func.getName()); + } + // 5. 
Try only 0 URI + if (uri != null && resolveUrnFromUri(uri) != null) { + return SimpleExtension.FunctionAnchor.of(resolveUrnFromUri(uri), func.getName()); + } + throw new IllegalStateException( + String.format( + "All resolution strategies failed for URI %s and URN %s (perhaps a URI <-> URN mapping was not registered during the migration) ", + uri, urn)); + } + + private SimpleExtension.TypeAnchor resolveTypeAnchor( + SimpleExtensionDeclaration.ExtensionType type, + Map urnMap, + Map uriMap) { + + // 1. Try non-zero URN reference + if (type.getExtensionUrnReference() != 0) { + String urnFromUrnRef = urnMap.get(type.getExtensionUrnReference()); + if (urnFromUrnRef == null) { + throw new IllegalStateException( + String.format( + "Type '%s' references URN anchor %d, but no URN is registered at that anchor", + type.getName(), type.getExtensionUrnReference())); + } + return SimpleExtension.TypeAnchor.of(urnFromUrnRef, type.getName()); + } + + // 2. Try non-zero URI reference + if (type.getExtensionUriReference() != 0) { + String uriFromUriRef = uriMap.get(type.getExtensionUriReference()); + if (uriFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Type '%s' references URI anchor %d, but no URI is registered at that anchor", + type.getName(), type.getExtensionUriReference())); + } + String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + if (urnFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Type '%s' references URI anchor %d with URI '%s', but this URI could not be resolved to a URN. " + + "Ensure a URI <-> URN mapping is registered in the ExtensionCollection.", + type.getName(), type.getExtensionUriReference(), uriFromUriRef)); + } + return SimpleExtension.TypeAnchor.of(urnFromUriRef, type.getName()); + } + + /* At this point, both URI and URN are known be 0. + * With protobufs, we cannot distinguish between 0 as an + * intentional value vs 0 as a default value. 
+ * We perform some additional checks to below to handle this. + */ + + String urn = urnMap.get(type.getExtensionUrnReference()); + String uri = uriMap.get(type.getExtensionUriReference()); + + // 3. Try both 0 URI and 0 URN if both resolve + if (uri != null && urn != null) { + String resolvedUrn = resolveUrnFromUri(uri); + if (urn.equals(resolvedUrn)) { + return SimpleExtension.TypeAnchor.of(urn, type.getName()); + } + throw new IllegalStateException( + String.format( + "Conflicting URI/URN mapping at reference 0: URI '%s' maps to URN '%s', but reference 0 also specifies URN '%s'. " + + "These must be consistent for proper resolution.", + uri, resolvedUrn, urn)); + } + + // 4. Try only 0 URN + if (urn != null) { + return SimpleExtension.TypeAnchor.of(urn, type.getName()); + } + // 5. Try only 0 URI + if (uri != null && resolveUrnFromUri(uri) != null) { + return SimpleExtension.TypeAnchor.of(resolveUrnFromUri(uri), type.getName()); + } + throw new IllegalStateException( + String.format( + "All resolution strategies failed for URI %s and URN %s (perhaps a URI <-> URN mapping was not registered during the migration) ", + uri, urn)); + } public Builder from(Plan plan) { - return from(plan.getExtensionUrisList(), plan.getExtensionsList()); + return from( + plan.getExtensionUrnsList(), plan.getExtensionUrisList(), plan.getExtensionsList()); } public Builder from(ExtendedExpression extendedExpression) { return from( - extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()); + extendedExpression.getExtensionUrnsList(), + extendedExpression.getExtensionUrisList(), + extendedExpression.getExtensionsList()); } private Builder from( - List simpleExtensionURIs, + List simpleExtensionURNs, + List simpleExtensionURIs, List simpleExtensionDeclarations) { - Map namespaceMap = new HashMap<>(); - for (SimpleExtensionURI extension : simpleExtensionURIs) { - namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); + Map urnMap = new 
HashMap<>(); + Map uriMap = new HashMap<>(); + + // Handle URN format + for (SimpleExtensionURN extension : simpleExtensionURNs) { + urnMap.put(extension.getExtensionUrnAnchor(), extension.getUrn()); + } + + // Handle deprecated URI format + for (io.substrait.proto.SimpleExtensionURI extension : simpleExtensionURIs) { + uriMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap @@ -53,13 +233,7 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); int reference = func.getFunctionAnchor(); - String namespace = namespaceMap.get(func.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + func.getExtensionUriReference()); - } - String name = func.getName(); - SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); + SimpleExtension.FunctionAnchor anchor = resolveFunctionAnchor(func, urnMap, uriMap); functionMap.put(reference, anchor); } @@ -70,13 +244,7 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); int reference = type.getTypeAnchor(); - String namespace = namespaceMap.get(type.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + type.getExtensionUriReference()); - } - String name = type.getName(); - SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); + SimpleExtension.TypeAnchor anchor = resolveTypeAnchor(type, urnMap, uriMap); typeMap.put(reference, anchor); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 14e11ca10..3198f7a26 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ 
-1,10 +1,10 @@ package io.substrait.extension; import com.fasterxml.jackson.annotation.JacksonInject; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.InjectableValues; import com.fasterxml.jackson.databind.ObjectMapper; @@ -26,8 +26,11 @@ import java.util.Map; import java.util.Optional; import java.util.OptionalInt; +import java.util.Scanner; import java.util.Set; +import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -41,12 +44,28 @@ public class SimpleExtension { private static final Logger LOGGER = LoggerFactory.getLogger(SimpleExtension.class); - // Key for looking up URI in InjectableValues - public static final String URI_LOCATOR_KEY = "uri"; + // Key for looking up URN in InjectableValues + public static final String URN_LOCATOR_KEY = "urn"; - private static ObjectMapper objectMapper(String namespace) { + private static final Predicate URN_CHECKER = + Pattern.compile("^extension:[^:]+:[^:]+$").asPredicate(); + + // `\A` means beginning of input. Using it as a delimiter in a scanner reads in the whole file. 
+ private static Pattern READ_WHOLE_FILE = Pattern.compile("\\A"); + + private static void validateUrn(String urn) { + if (urn == null || urn.trim().isEmpty()) { + throw new IllegalArgumentException("URN cannot be null or empty"); + } + if (!URN_CHECKER.test(urn)) { + throw new IllegalArgumentException( + "URN must follow format 'extension::', got: " + urn); + } + } + + private static ObjectMapper objectMapper(String urn) { InjectableValues.Std iv = new InjectableValues.Std(); - iv.addValue(URI_LOCATOR_KEY, namespace); + iv.addValue(URN_LOCATOR_KEY, urn); return new ObjectMapper(new YAMLFactory()) .enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @@ -184,25 +203,22 @@ public static ImmutableSimpleExtension.EnumArgument.Builder builder() { } public interface Anchor { - String namespace(); + String urn(); String key(); } @Value.Immutable public interface FunctionAnchor extends Anchor { - static FunctionAnchor of(String namespace, String key) { - return ImmutableSimpleExtension.FunctionAnchor.builder() - .namespace(namespace) - .key(key) - .build(); + static FunctionAnchor of(String urn, String key) { + return ImmutableSimpleExtension.FunctionAnchor.builder().urn(urn).key(key).build(); } } @Value.Immutable public interface TypeAnchor extends Anchor { - static TypeAnchor of(String namespace, String name) { - return ImmutableSimpleExtension.TypeAnchor.builder().namespace(namespace).key(name).build(); + static TypeAnchor of(String urn, String name) { + return ImmutableSimpleExtension.TypeAnchor.builder().urn(urn).key(name).build(); } } @@ -226,7 +242,7 @@ default ParameterConsistency parameterConsistency() { public abstract static class Function { private final Supplier anchorSupplier = - Util.memoize(() -> FunctionAnchor.of(uri(), key())); + Util.memoize(() -> FunctionAnchor.of(urn(), key())); private final Supplier keySupplier = Util.memoize(() -> constructKey(name(), args())); private final Supplier> requiredArgsSupplier = Util.memoize( @@ -241,8 +257,8 @@ 
public String name() { } @Value.Default - public String uri() { - // we can't use null detection here since we initially construct this without a uri, then + public String urn() { + // we can't use null detection here since we initially construct this without a urn, then // resolve later. return ""; } @@ -366,8 +382,8 @@ public abstract static class ScalarFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -375,9 +391,9 @@ public Stream resolve(String uri) { @JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class) @Value.Immutable public abstract static class ScalarFunctionVariant extends Function { - public ScalarFunctionVariant resolve(String uri, String name, String description) { + public ScalarFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -402,8 +418,8 @@ public abstract static class AggregateFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -419,8 +435,8 @@ public abstract static class WindowFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } public static ImmutableSimpleExtension.WindowFunction.Builder builder() { @@ -446,9 +462,9 @@ public String toString() { @Nullable public abstract TypeExpression 
intermediate(); - AggregateFunctionVariant resolve(String uri, String name, String description) { + AggregateFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.AggregateFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -488,9 +504,9 @@ public String toString() { return super.toString(); } - WindowFunctionVariant resolve(String uri, String name, String description) { + WindowFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.WindowFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -515,12 +531,12 @@ public static ImmutableSimpleExtension.WindowFunctionVariant.Builder builder() { @Value.Immutable public abstract static class Type { private final Supplier anchorSupplier = - Util.memoize(() -> TypeAnchor.of(uri(), name())); + Util.memoize(() -> TypeAnchor.of(urn(), name())); public abstract String name(); - @JacksonInject(SimpleExtension.URI_LOCATOR_KEY) - public abstract String uri(); + @JacksonInject(SimpleExtension.URN_LOCATOR_KEY) + public abstract String urn(); // TODO: Handle conversion of structure object to Named Struct representation protected abstract Optional structure(); @@ -532,11 +548,15 @@ public TypeAnchor getAnchor() { @JsonDeserialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) @JsonSerialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) + @JsonIgnoreProperties(ignoreUnknown = true) @Value.Immutable public abstract static class ExtensionSignatures { @JsonProperty("types") public abstract List types(); + @JsonProperty("urn") + public abstract String urn(); + @JsonProperty("scalar_functions") public abstract List scalars(); @@ -553,27 +573,27 @@ public int size() { + (windows() == null ? 
0 : windows().size()); } - public Stream resolve(String uri) { + public Stream resolve(String urn) { return Stream.concat( Stream.concat( - scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(uri)), + scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(urn)), aggregates() == null ? Stream.of() - : aggregates().stream().flatMap(f -> f.resolve(uri))), - windows() == null ? Stream.of() : windows().stream().flatMap(f -> f.resolve(uri))); + : aggregates().stream().flatMap(f -> f.resolve(urn))), + windows() == null ? Stream.of() : windows().stream().flatMap(f -> f.resolve(urn))); } } @Value.Immutable public abstract static class ExtensionCollection { - private final Supplier> namespaceSupplier = + private final Supplier> urnSupplier = Util.memoize( () -> { return Stream.concat( Stream.concat( - scalarFunctions().stream().map(Function::uri), - aggregateFunctions().stream().map(Function::uri)), - windowFunctions().stream().map(Function::uri)) + scalarFunctions().stream().map(Function::urn), + aggregateFunctions().stream().map(Function::urn)), + windowFunctions().stream().map(Function::urn)) .collect(Collectors.toSet()); }); @@ -610,6 +630,11 @@ public abstract static class ExtensionCollection { Function::getAnchor, java.util.function.Function.identity())); }); + @Value.Default + BidiMap uriUrnMap() { + return new BidiMap<>(); + } + public abstract List types(); public abstract List scalarFunctions(); @@ -627,11 +652,11 @@ public Type getType(TypeAnchor anchor) { if (type != null) { return type; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected type with name %s. The namespace %s is loaded but no type with this name found.", - anchor.key(), anchor.namespace())); + "Unexpected type with name %s. 
The URN %s is loaded but no type with this name found.", + anchor.key(), anchor.urn())); } public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { @@ -639,16 +664,16 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { if (variant != null) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected scalar function with key %s. The namespace %s is loaded " + "Unexpected scalar function with key %s. The URN %s is loaded " + "but no scalar function with this key found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } - private void checkNamespace(String name) { - if (namespaceSupplier.get().contains(name)) { + private void checkUrn(String name) { + if (urnSupplier.get().contains(name)) { return; } @@ -665,12 +690,12 @@ public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected aggregate function with key %s. The namespace %s is loaded " + "Unexpected aggregate function with key %s. The URN %s is loaded " + "but no aggregate function with this key was found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { @@ -678,15 +703,39 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { if (variant != null) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected window aggregate function with key %s. The namespace %s is loaded " + "Unexpected window aggregate function with key %s. 
The URN %s is loaded " + "but no window aggregate function with this key was found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); + } + + /** + * Gets the URI for a given URN. This is for internal framework use during URI/URN migration. + * + * @param urn The URN to look up + * @return The corresponding URI, or null if not found + */ + String getUriFromUrn(String urn) { + return uriUrnMap().reverseGet(urn); + } + + /** + * Gets the URN for a given URI. This is for internal framework use during URI/URN migration. + * + * @param uri The URI to look up + * @return The corresponding URN, or null if not found + */ + String getUrnFromUri(String uri) { + return uriUrnMap().get(uri); } public ExtensionCollection merge(ExtensionCollection extensionCollection) { + BidiMap mergedUriUrnMap = new BidiMap<>(); + mergedUriUrnMap.merge(uriUrnMap()); + mergedUriUrnMap.merge(extensionCollection.uriUrnMap()); + return ImmutableSimpleExtension.ExtensionCollection.builder() .addAllAggregateFunctions(aggregateFunctions()) .addAllAggregateFunctions(extensionCollection.aggregateFunctions()) @@ -696,6 +745,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { .addAllWindowFunctions(extensionCollection.windowFunctions()) .addAllTypes(types()) .addAllTypes(extensionCollection.types()) + .uriUrnMap(mergedUriUrnMap) .build(); } } @@ -723,41 +773,61 @@ public static ExtensionCollection load(List resourcePaths) { return complete; } - public static ExtensionCollection load(String namespace, String str) { + public static ExtensionCollection load(String uri, String content) { try { - ExtensionSignatures doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); - return buildExtensionCollection(namespace, doc); - } catch (JsonProcessingException e) { + if (uri == null || uri.isEmpty()) { + throw new IllegalArgumentException("URI cannot be null or empty"); + } + + // Parse with basic YAML mapper first to extract URN + ObjectMapper 
basicYamlMapper = new ObjectMapper(new YAMLFactory()); + com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content); + com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn"); + if (urnNode == null) { + throw new IllegalArgumentException("Extension YAML file must contain a 'urn' field"); + } + String urn = urnNode.asText(); + validateUrn(urn); + + ExtensionSignatures docWithoutUri = + objectMapper(urn).readValue(content, ExtensionSignatures.class); + + ExtensionSignatures doc = + ImmutableSimpleExtension.ExtensionSignatures.builder().from(docWithoutUri).build(); + + return buildExtensionCollection(uri, doc); + } catch (IOException e) { throw new IllegalStateException(e); } } - public static ExtensionCollection load(String namespace, InputStream stream) { - try { - ExtensionSignatures doc = - objectMapper(namespace).readValue(stream, ExtensionSignatures.class); - return buildExtensionCollection(namespace, doc); - } catch (RuntimeException ex) { - throw ex; - } catch (Exception ex) { - throw new IllegalStateException("Failure while parsing " + namespace, ex); + public static ExtensionCollection load(String uri, InputStream stream) { + try (Scanner scanner = new Scanner(stream, "UTF-8")) { + scanner.useDelimiter(READ_WHOLE_FILE); + String content = scanner.next(); + return load(uri, content); } } public static ExtensionCollection buildExtensionCollection( - String namespace, ExtensionSignatures extensionSignatures) { + String uri, ExtensionSignatures extensionSignatures) { + String urn = extensionSignatures.urn(); + validateUrn(urn); + if (uri == null || uri.isEmpty()) { + throw new IllegalArgumentException("URI cannot be null or empty"); + } List scalarFunctionVariants = extensionSignatures.scalars().stream() - .flatMap(t -> t.resolve(namespace)) + .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); List aggregateFunctionVariants = extensionSignatures.aggregates().stream() - .flatMap(t -> t.resolve(namespace)) + .flatMap(t ->
t.resolve(urn)) .collect(Collectors.toList()); Stream windowFunctionVariants = - extensionSignatures.windows().stream().flatMap(t -> t.resolve(namespace)); + extensionSignatures.windows().stream().flatMap(t -> t.resolve(urn)); // Aggregate functions can be used as Window Functions Stream windowAggFunctionVariants = @@ -778,18 +848,23 @@ public static ExtensionCollection buildExtensionCollection( Stream.concat(windowFunctionVariants, windowAggFunctionVariants) .collect(Collectors.toList()); + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put(uri, urn); + ImmutableSimpleExtension.ExtensionCollection collection = ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) .aggregateFunctions(aggregateFunctionVariants) .windowFunctions(allWindowFunctionVariants) .addAllTypes(extensionSignatures.types()) + .uriUrnMap(uriUrnMap) .build(); + LOGGER.atDebug().log( "Loaded {} aggregate functions and {} scalar functions from {}.", collection.aggregateFunctions().size(), collection.scalarFunctions().size(), - namespace); + extensionSignatures.urn()); return collection; } } diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index 5b8e2599e..5785ccd70 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -1,6 +1,8 @@ package io.substrait.plan; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; import io.substrait.proto.Rel; @@ -12,9 +14,22 @@ /** Converts from {@link io.substrait.plan.Plan} to {@link io.substrait.proto.Plan} */ public class PlanProtoConverter { + private final SimpleExtension.ExtensionCollection extensionCollection; + + public PlanProtoConverter() { + 
this(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + public PlanProtoConverter(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } + this.extensionCollection = extensionCollection; + } + public Plan toProto(io.substrait.plan.Plan plan) { List planRels = new ArrayList<>(); - ExtensionCollector functionCollector = new ExtensionCollector(); + ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection); for (io.substrait.plan.Plan.Root root : plan.getRoots()) { Rel input = new RelProtoConverter(functionCollector).toProto(root.getInput()); planRels.add( diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 15145fe5c..596398c5b 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -21,6 +21,9 @@ public ProtoPlanConverter() { } public ProtoPlanConverter(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } this.extensionCollection = extensionCollection; } @@ -30,7 +33,8 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 160004a3b..efdbc3306 100644 --- 
a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -44,7 +44,7 @@ public T deserialize(final JsonParser p, final DeserializationContext ctxt) String typeString = p.getValueAsString(); try { String namespace = - (String) ctxt.findInjectableValue(SimpleExtension.URI_LOCATOR_KEY, null, null); + (String) ctxt.findInjectableValue(SimpleExtension.URN_LOCATOR_KEY, null, null); return TypeStringParser.parse(typeString, namespace, converter); } catch (Exception ex) { throw JsonMappingException.from( diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 362adc3a8..aaf97aa12 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -389,7 +389,7 @@ public R accept(final TypeVisitor typeVisitor) th @Value.Immutable abstract class UserDefined implements Type { - public abstract String uri(); + public abstract String urn(); public abstract String name(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 880b72ed9..43358e505 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -108,8 +108,8 @@ public Type.Map map(Type key, Type value) { return Type.Map.builder().nullable(nullable).key(key).value(value).build(); } - public Type userDefined(String uri, String name) { - return Type.UserDefined.builder().nullable(nullable).uri(uri).name(name).build(); + public Type userDefined(String urn, String name) { + return Type.UserDefined.builder().nullable(nullable).urn(urn).name(name).build(); } public static TypeCreator of(boolean nullability) { diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index 07093e925..8085ea6b8 100644 --- 
a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -16,16 +16,16 @@ public class TypeStringParser { private TypeStringParser() {} - public static Type parseSimple(String str, String namespace) { - return parse(str, namespace, ParseToPojo::type); + public static Type parseSimple(String str, String urn) { + return parse(str, urn, ParseToPojo::type); } - public static ParameterizedType parseParameterized(String str, String namespace) { - return parse(str, namespace, ParseToPojo::parameterizedType); + public static ParameterizedType parseParameterized(String str, String urn) { + return parse(str, urn, ParseToPojo::parameterizedType); } - public static TypeExpression parseExpression(String str, String namespace) { - return parse(str, namespace, ParseToPojo::typeExpression); + public static TypeExpression parseExpression(String str, String urn) { + return parse(str, urn, ParseToPojo::typeExpression); } private static SubstraitTypeParser.StartContext parse(String str) { @@ -40,8 +40,8 @@ private static SubstraitTypeParser.StartContext parse(String str) { } public static T parse( - String str, String namespace, BiFunction func) { - return func.apply(namespace, parse(str)); + String str, String urn, BiFunction func) { + return func.apply(urn, parse(str)); } public static TypeExpression parse(String str, ParseToPojo.Visitor visitor) { diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index a8c64db95..691d4bce5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -164,7 +164,7 @@ public final T visit(final Type.Map expr) { @Override public final T visit(final Type.UserDefined expr) { int ref = - extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); 
+ extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return typeContainer(expr).userDefined(ref); } } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 661f57fea..95d42328a 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -90,7 +90,7 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.uri(), t.name()); + return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); } case USER_DEFINED_TYPE_REFERENCE: throw new UnsupportedOperationException("Unsupported user defined reference: " + type); diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java new file mode 100644 index 000000000..d8de8a619 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java @@ -0,0 +1,87 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class ExtensionCollectionMergeTest { + + @Test + public void testMergeCollectionsWithDifferentUriUrnMappings() { + String yaml1 = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:ns1:collection1\n" + + "scalar_functions:\n" + + " - name: func1\n" + + " impls:\n" + + " - args: []\n" + + " return: boolean\n"; + + String yaml2 = + "%YAML 1.2\n" + + "---\n" + + "urn: 
extension:ns2:collection2\n" + + "scalar_functions:\n" + + " - name: func2\n" + + " impls:\n" + + " - args: []\n" + + " return: i32\n"; + + SimpleExtension.ExtensionCollection collection1 = + SimpleExtension.load("uri1://extensions", yaml1); + SimpleExtension.ExtensionCollection collection2 = + SimpleExtension.load("uri2://extensions", yaml2); + + SimpleExtension.ExtensionCollection merged = collection1.merge(collection2); + + assertEquals("extension:ns1:collection1", merged.getUrnFromUri("uri1://extensions")); + assertEquals("extension:ns2:collection2", merged.getUrnFromUri("uri2://extensions")); + assertEquals("uri1://extensions", merged.getUriFromUrn("extension:ns1:collection1")); + assertEquals("uri2://extensions", merged.getUriFromUrn("extension:ns2:collection2")); + + assertTrue(merged.scalarFunctions().size() >= 2); + } + + @Test + public void testMergeCollectionsWithIdenticalMappings() { + String yaml = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:shared:extension\n" + + "scalar_functions:\n" + + " - name: shared_func\n" + + " impls:\n" + + " - args: []\n" + + " return: boolean\n"; + + SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("shared://uri", yaml); + SimpleExtension.ExtensionCollection collection2 = SimpleExtension.load("shared://uri", yaml); + + SimpleExtension.ExtensionCollection merged = + assertDoesNotThrow(() -> collection1.merge(collection2)); + + assertEquals("extension:shared:extension", merged.getUrnFromUri("shared://uri")); + assertEquals("shared://uri", merged.getUriFromUrn("extension:shared:extension")); + } + + @Test + public void testMergeCollectionsWithConflictingMappings() { + String yaml1 = + "%YAML 1.2\n" + "---\n" + "urn: extension:conflict:urn1\n" + "scalar_functions: []\n"; + + String yaml2 = + "%YAML 1.2\n" + "---\n" + "urn: extension:conflict:urn2\n" + "scalar_functions: []\n"; + + SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("conflict://uri", yaml1); + 
SimpleExtension.ExtensionCollection collection2 = + SimpleExtension.load("conflict://uri", yaml2); // Same URI, different URN + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> collection1.merge(collection2)); + assertTrue(exception.getMessage().contains("Key already exists in map with different value")); + } +} diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java new file mode 100644 index 000000000..48f51cbbd --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java @@ -0,0 +1,60 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class ExtensionCollectionUriUrnTest { + + @Test + public void testHasUrnAndHasUri() { + String yamlContent = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:exists\n" + + "scalar_functions:\n" + + " - name: test_function\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("file:///tmp/test.yaml", yamlContent); + + assertTrue(collection.getUrnFromUri("file:///tmp/test.yaml") != null); + assertTrue(collection.getUriFromUrn("extension:test:exists") != null); + assertFalse(collection.getUrnFromUri("nonexistent://uri") != null); + assertFalse(collection.getUriFromUrn("extension:nonexistent:urn") != null); + } + + @Test + public void testGetNonexistentMappings() { + String yamlContent = + "%YAML 1.2\n" + "---\n" + "urn: extension:test:minimal\n" + "scalar_functions: []\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("minimal://extension", yamlContent); + + 
assertNull(collection.getUrnFromUri("nonexistent://uri")); + assertNull(collection.getUriFromUrn("extension:nonexistent:urn")); + } + + @Test + public void testEmptyUriThrowsException() { + String yamlContent = + "%YAML 1.2\n" + "---\n" + "urn: extension:test:empty\n" + "scalar_functions: []\n"; + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("", yamlContent)); + assertTrue(exception.getMessage().contains("URI cannot be null or empty")); + } + + @Test + public void testNullUriThrowsException() { + String yamlContent = + "%YAML 1.2\n" + "---\n" + "urn: extension:test:null\n" + "scalar_functions: []\n"; + + // A null URI is rejected with IllegalArgumentException (not a NullPointerException) + assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load(null, yamlContent)); + } +} diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java new file mode 100644 index 000000000..4e6b52fdf --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java @@ -0,0 +1,41 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.proto.Plan; +import org.junit.jupiter.api.Test; + +public class ExtensionCollectorUriUrnTest { + + @Test + public void testExtensionCollectorScalarFuncWithoutURI() { + String uri = "test://uri"; + BidiMap uriUrnMap = new BidiMap(); + uriUrnMap.put(uri, "extension:test:basic"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + ExtensionCollector collector = new ExtensionCollector(extensionCollection); + + SimpleExtension.ScalarFunctionVariant func = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:test:basic") + .name("test_func") +
.returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) + .build(); + + int functionRef = collector.getFunctionReference(func); + assertEquals(1, functionRef); + + Plan.Builder planBuilder = Plan.newBuilder(); + collector.addExtensionsToPlan(planBuilder); + + Plan plan = planBuilder.build(); + assertEquals(1, plan.getExtensionUrnsCount()); + assertEquals("extension:test:basic", plan.getExtensionUrns(0).getUrn()); + + assertEquals(1, plan.getExtensionUrisCount()); + assertEquals("test://uri", plan.getExtensionUris(0).getUri()); + } +} diff --git a/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java new file mode 100644 index 000000000..010bdfef4 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java @@ -0,0 +1,621 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.proto.Plan; +import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; +import org.junit.jupiter.api.Test; + +public class ImmutableExtensionLookupUriUrnTest { + + @Test + public void testUrnResolutionWorks() { + // Create URN-only plan (normal case) + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(1) + .setUrn("extension:test:urn") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("test_func") + .setExtensionUrnReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = 
Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + // Test with no ExtensionCollection (no URI/URN mapping available) + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:urn", lookup.functionAnchorMap.get(1).urn()); + assertEquals("test_func", lookup.functionAnchorMap.get(1).key()); + } + + @Test + public void testUriToUrnFallbackWorks() { + // Create an ExtensionCollection with URI/URN mapping + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/extensions/test", "extension:test:mapped"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + // Create URI-only plan (legacy case) + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/extensions/test") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("legacy_func") + .setExtensionUriReference(1) // References the URI anchor (deprecated field) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + // Test with URI/URN mapping - should resolve URI to URN + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:mapped", lookup.functionAnchorMap.get(1).urn()); + assertEquals("legacy_func", lookup.functionAnchorMap.get(1).key()); + } + + @Test + public void testUriWithoutMappingThrowsError() { + // Create URI-only plan without mapping + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/unmapped") + .build(); + + 
SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("unmapped_func") + .setExtensionUriReference(1) // References the URI anchor + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + // Should throw error - URI present but no mapping available + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder().from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("could not be resolved to a URN")); + assertTrue(exception.getMessage().contains("http://example.com/unmapped")); + assertTrue(exception.getMessage().contains("URI <-> URN mapping")); + } + + @Test + public void testMissingUrnAndUriThrowsError() { + // Create plan with missing URN/URI reference + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("missing_func") + .setExtensionUrnReference(999) // Non-existent reference + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensions(decl).build(); + + // Should throw error - neither URN nor URI found + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder().from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("no URN is registered at that anchor")); + assertTrue(exception.getMessage().contains("999")); // The missing anchor reference + } + + // ========================================================================== + // Simple tests for all 5 resolution cases - Functions + // 
========================================================================== + + @Test + public void testFunctionCase1_NonZeroUrnReference() { + // Case 1: Non-zero URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(1) + .setUrn("extension:test:case1") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case1_func") + .setExtensionUrnReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case1", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case1_func", lookup.functionAnchorMap.get(1).key()); + } + + @Test + public void testFunctionCase2_NonZeroUriReference() { + // Case 2: Non-zero URI reference resolves via mapping + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case2", "extension:test:case2"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/case2") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case2_func") + .setExtensionUriReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + 
ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case2", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case2_func", lookup.functionAnchorMap.get(1).key()); + } + + @Test + public void testFunctionCase3_ZeroBothResolveConsistent() { + // Case 3: Both 0 references resolve to consistent URN + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case3", "extension:test:case3"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case3") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case3") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case3_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case3", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case3_func", lookup.functionAnchorMap.get(1).key()); + } + + @Test + public void testFunctionCase3_ZeroBothResolveConflict() { + // Case 3: Both 0 references resolve but to different URNs - should throw + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/conflict", "extension:test:different"); + + SimpleExtension.ExtensionCollection extensionCollection = + 
SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:original") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/conflict") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("conflict_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("Conflicting URI/URN mapping")); + assertTrue(exception.getMessage().contains("These must be consistent")); + } + + @Test + public void testFunctionCase4_ZeroUrnOnly() { + // Case 4: Only 0 URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case4") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case4_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + 
assertEquals("extension:test:case4", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case4_func", lookup.functionAnchorMap.get(1).key()); + } + + @Test + public void testFunctionCase5_ZeroUriOnly() { + // Case 5: Only 0 URI reference resolves + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case5", "extension:test:case5"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case5") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case5_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case5", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case5_func", lookup.functionAnchorMap.get(1).key()); + } + + // ========================================================================== + // Simple tests for all 5 resolution cases - Types + // ========================================================================== + + @Test + public void testTypeCase1_NonZeroUrnReference() { + // Case 1: Non-zero URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(1) + .setUrn("extension:test:case1") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case1_type") + 
.setExtensionUrnReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case1", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case1_type", lookup.typeAnchorMap.get(1).key()); + } + + @Test + public void testTypeCase2_NonZeroUriReference() { + // Case 2: Non-zero URI reference resolves via mapping + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case2", "extension:test:case2"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/case2") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case2_type") + .setExtensionUriReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case2", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case2_type", lookup.typeAnchorMap.get(1).key()); + } + + @Test + public void testTypeCase3_ZeroBothResolveConsistent() { + // Case 3: Both 0 references resolve to consistent URN + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case3", "extension:test:case3"); + + SimpleExtension.ExtensionCollection extensionCollection = + 
SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case3") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case3") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case3_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case3", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case3_type", lookup.typeAnchorMap.get(1).key()); + } + + @Test + public void testTypeCase3_ZeroBothResolveConflict() { + // Case 3: Both 0 references resolve but to different URNs - should throw + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/conflict", "extension:test:different"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:original") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/conflict") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("conflict_type") + .setExtensionUrnReference(0) + 
.setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("Conflicting URI/URN mapping")); + assertTrue(exception.getMessage().contains("These must be consistent")); + } + + @Test + public void testTypeCase4_ZeroUrnOnly() { + // Case 4: Only 0 URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case4") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case4_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case4", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case4_type", lookup.typeAnchorMap.get(1).key()); + } + + @Test + public void testTypeCase5_ZeroUriOnly() { + // Case 5: Only 0 URI reference resolves + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case5", "extension:test:case5"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + 
.setUri("http://example.com/case5") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case5_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case5", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case5_type", lookup.typeAnchorMap.get(1).key()); + } + + @Test + public void testTypeUriToUrnFallbackWorks() { + // Test the same logic but for types instead of functions + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/types/test", "extension:types:mapped"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/types/test") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("legacy_type") + .setExtensionUriReference(1) // References the URI anchor + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:types:mapped", lookup.typeAnchorMap.get(1).urn()); + assertEquals("legacy_type", lookup.typeAnchorMap.get(1).key()); + } +} diff --git 
a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index cd2522090..2e8fd8403 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -28,13 +28,13 @@ public class TypeExtensionTest { static final TypeCreator R = TypeCreator.of(false); - static final String NAMESPACE = "/custom_extensions"; + static final String NAMESPACE = "extension:test:custom_extensions"; final SimpleExtension.ExtensionCollection extensionCollection; { - InputStream inputStream = - this.getClass().getResourceAsStream("/extensions/custom_extensions.yaml"); - extensionCollection = SimpleExtension.load(NAMESPACE, inputStream); + String path = "/extensions/custom_extensions.yaml"; + InputStream inputStream = this.getClass().getResourceAsStream(path); + extensionCollection = SimpleExtension.load(path, inputStream); } final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); diff --git a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java new file mode 100644 index 000000000..ca2334c9c --- /dev/null +++ b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java @@ -0,0 +1,110 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.plan.PlanProtoConverter; +import io.substrait.plan.ProtoPlanConverter; +import io.substrait.proto.Plan; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import 
java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +/** + * End-to-end tests demonstrating the full URI/URN migration workflow: 1. Consume plans with mixed + * URI/URN references 2. Convert proto -> POJO using ImmutableExtensionLookup with URI/URN mapping + * 3. Convert POJO -> proto using PlanProtoConverter 4. Verify output contains proper + * extension information + */ +public class UriUrnMigrationEndToEndTest { + + /** Load a proto Plan from a JSON resource file using JsonFormat */ + private Plan loadPlanFromJson(String resourcePath) throws IOException { + try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream(resourcePath)) { + if (inputStream == null) { + throw new IOException("Resource not found: " + resourcePath); + } + + String jsonContent = + new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + .lines() + .collect(Collectors.joining("\n")); + + Plan.Builder planBuilder = Plan.newBuilder(); + JsonFormat.parser().merge(jsonContent, planBuilder); + return planBuilder.build(); + } + } + + @Test + public void testUriUrnMigrationEndToEnd() throws IOException { + + // List of (inputPath, expectedPath) pairs + List testCases = + Arrays.asList( + new String[] { + "uri-urn-migration/uri-only-input-plan.json", + "uri-urn-migration/uri-only-expected-plan.json" + }, + new String[] { + "uri-urn-migration/complex-input-plan.json", + "uri-urn-migration/complex-expected-plan.json" + }, + new String[] { + "uri-urn-migration/urn-only-input-plan.json", + "uri-urn-migration/urn-only-expected-plan.json" + }, + new String[] { + "uri-urn-migration/mixed-partial-coverage-input-plan.json", + "uri-urn-migration/mixed-partial-coverage-expected-plan.json" + }, + new String[] { + "uri-urn-migration/zero-urn-resolution-input-plan.json", + "uri-urn-migration/zero-urn-resolution-expected-plan.json" + }); + + for (String[] testCase : testCases) { + String inputPath = testCase[0]; + String expectedPath 
= testCase[1]; + + Plan inputPlan = loadPlanFromJson(inputPath); + Plan expectedPlan = loadPlanFromJson(expectedPath); + + ProtoPlanConverter protoToPojo = + new ProtoPlanConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); + io.substrait.plan.Plan pojoPlan = protoToPojo.from(inputPlan); + + PlanProtoConverter pojoToProto = + new PlanProtoConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); + Plan actualPlan = pojoToProto.toProto(pojoPlan); + + assertEquals(expectedPlan, actualPlan); + } + } + + @Test + public void testUnresolvableUriThrowsException() throws IOException { + Plan inputPlan = loadPlanFromJson("uri-urn-migration/unresolvable-uri-plan.json"); + + ProtoPlanConverter protoToPojo = + new ProtoPlanConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + protoToPojo.from(inputPlan); + }); + + assertTrue(exception.getMessage().contains("could not be resolved to a URN")); + assertTrue(exception.getMessage().contains("/functions_nonexistent.yaml")); + } +} diff --git a/core/src/test/java/io/substrait/extension/UrnValidationTest.java b/core/src/test/java/io/substrait/extension/UrnValidationTest.java new file mode 100644 index 000000000..983eb4f41 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/UrnValidationTest.java @@ -0,0 +1,60 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class UrnValidationTest { + + @Test + public void testMissingUrnThrowsException() { + String yamlWithoutUrn = "%YAML 1.2\n" + "---\n" + "scalar_functions:\n" + " - name: test\n"; + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> 
SimpleExtension.load("some/uri", yamlWithoutUrn)); + assertTrue(exception.getMessage().contains("Extension YAML file must contain a 'urn' field")); + } + + @Test + public void testInvalidUrnFormatThrowsException() { + String yamlWithInvalidUrn = + "%YAML 1.2\n" + + "---\n" + + "urn: invalid:format\n" + + "scalar_functions:\n" + + " - name: test\n"; + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> SimpleExtension.load("some/uri", yamlWithInvalidUrn)); + assertTrue( + exception.getMessage().contains("URN must follow format 'extension::'")); + } + + @Test + public void testValidUrnWorks() { + String yamlWithValidUrn = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:valid\n" + + "scalar_functions:\n" + + " - name: test\n"; + assertDoesNotThrow(() -> SimpleExtension.load("some/uri", yamlWithValidUrn)); + } + + @Test + public void testUriUrnMapIsPopulated() { + String yamlWithValidUrn = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:valid\n" + + "scalar_functions:\n" + + " - name: test\n"; + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("test://uri", yamlWithValidUrn); + assertEquals("extension:test:valid", collection.getUrnFromUri("test://uri")); + } +} diff --git a/core/src/test/resources/extensions/custom_extensions.yaml b/core/src/test/resources/extensions/custom_extensions.yaml index 204a5f9ac..4776312ac 100644 --- a/core/src/test/resources/extensions/custom_extensions.yaml +++ b/core/src/test/resources/extensions/custom_extensions.yaml @@ -1,5 +1,6 @@ %YAML 1.2 --- +urn: extension:test:custom_extensions types: - name: "customType1" - name: "customType2" diff --git a/core/src/test/resources/uri-urn-migration/complex-expected-plan.json b/core/src/test/resources/uri-urn-migration/complex-expected-plan.json new file mode 100644 index 000000000..09833ace9 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/complex-expected-plan.json @@ -0,0 +1,162 @@ +{ + "extensionUris": [ + 
{ + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 3, + "name": "multiply:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 5 + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 7 + } + } + } + ] + } + }, + { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 3 + } + } + }, + { + "value": { + "literal": { + "i32": 4 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result", + "product" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git 
a/core/src/test/resources/uri-urn-migration/complex-input-plan.json b/core/src/test/resources/uri-urn-migration/complex-input-plan.json new file mode 100644 index 000000000..da28cbb33 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/complex-input-plan.json @@ -0,0 +1,148 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 0, + "uri": "/functions_string.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 2, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUrnReference": 2 + } + }, + { + "extensionFunction": { + "functionAnchor": 3, + "name": "multiply:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 5 + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 7 + } + } + } + ] + } + }, + { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 3 + } + } + }, + { + "value": { + "literal": { + "i32": 4 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result", + "product" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git 
a/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json new file mode 100644 index 000000000..2898243b3 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "multiply:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "x", + "y" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 2 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "names": [ + "calculation" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git 
a/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json new file mode 100644 index 000000000..e02019925 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json @@ -0,0 +1,110 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "multiply:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "x", + "y" + ], + "struct": { + "types": [ + { + "i32": {} + }, + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 2 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "names": [ + "calculation" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json b/core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json new file mode 100644 index 000000000..769d359e7 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json @@ -0,0 +1,73 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_nonexistent.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": 
"nonexistent_function:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 7 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json b/core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json new file mode 100644 index 000000000..1ad0150e3 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json @@ -0,0 +1,98 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 7 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + 
} + } + ] + } + }, + "names": [ + "result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/uri-only-input-plan.json b/core/src/test/resources/uri-urn-migration/uri-only-input-plan.json new file mode 100644 index 000000000..5630a00ca --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/uri-only-input-plan.json @@ -0,0 +1,80 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 7 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json b/core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json new file mode 100644 index 000000000..9e47d1e74 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": 
"subtract:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "value1", + "value2" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "numbers" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 20 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/urn-only-input-plan.json b/core/src/test/resources/uri-urn-migration/urn-only-input-plan.json new file mode 100644 index 000000000..b5f5af06e --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/urn-only-input-plan.json @@ -0,0 +1,110 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + 
"value1", + "value2" + ], + "struct": { + "types": [ + { + "i32": {} + }, + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "numbers" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 20 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json new file mode 100644 index 000000000..b8788c850 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json @@ -0,0 +1,104 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a", + "b" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + 
"i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json new file mode 100644 index 000000000..bc4983846 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json @@ -0,0 +1,85 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 0, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 0, + "extensionUrnReference": 0 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "a", + "b" + ], + "struct": { + "types": [ + { + "i32": {} + }, + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index c0e1796da..3406de7de 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -71,7 +71,7 @@ public class CallConverters { Type.UserDefined t = (Type.UserDefined) type; 
return Expression.UserDefinedLiteral.builder() - .uri(t.uri()) + .urn(t.urn()) .name(t.name()) .value(literal.value()) .build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index cec68fcb7..b69ef9b02 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -185,7 +185,7 @@ private static ArgAnchor argAnchor(String fnNS, String fnSig, int argIdx) { private static ArgAnchor argAnchor(SimpleExtension.Function fnDef, int argIdx) { return new ArgAnchor( - SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().namespace(), fnDef.getAnchor().key()), + SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().urn(), fnDef.getAnchor().key()), argIdx); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java index a12a1eac6..176552f7d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java @@ -37,7 +37,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.uri().equals(uTypeURI) && type.name().equals(uTypeName)) { + if (type.urn().equals(uTypeURI) && type.name().equals(uTypeName)) { return uTypeFactory.createCalcite(type.nullable()); } return null; diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index fabddf56e..d351b1bca 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -43,7 +43,7 @@ public class CustomFunctionTest extends PlanTestBase { // Define custom functions in a "functions_custom.yaml" 
extension - static final String NAMESPACE = "/functions_custom"; + static final String NAMESPACE = "extension:substrait:functions_custom"; static final String FUNCTIONS_CUSTOM; static { @@ -56,7 +56,7 @@ public class CustomFunctionTest extends PlanTestBase { // Load custom extension into an ExtensionCollection static final SimpleExtension.ExtensionCollection extensionCollection = - SimpleExtension.load("/functions_custom", FUNCTIONS_CUSTOM); + SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); @@ -84,7 +84,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.uri().equals(NAMESPACE)) { + if (type.urn().equals(NAMESPACE)) { if (type.name().equals(aTypeName)) { return aTypeFactory.createCalcite(type.nullable()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index ceb01e9e3..b10fbab09 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -29,14 +29,15 @@ public class RelCopyOnWriteVisitorTest extends PlanTestBase { public static SimpleExtension.FunctionAnchor APPROX_COUNT_DISTINCT = SimpleExtension.FunctionAnchor.of( - "/functions_aggregate_approx.yaml", "approx_count_distinct:any"); + "extension:io.substrait:functions_aggregate_approx", "approx_count_distinct:any"); public static SimpleExtension.FunctionAnchor COUNT = - SimpleExtension.FunctionAnchor.of("/functions_aggregate_generic.yaml", "count:any"); + SimpleExtension.FunctionAnchor.of( + "extension:io.substrait:functions_aggregate_generic", "count:any"); private static final String COUNT_DISTINCT_SUBBQUERY = "select\n" + " count(distinct l.l_orderkey),\n" - + " count(distinct l.l_orderkey) + 1,\n" + + " count(distinct 
l.l_orderkey) + 1,\n" + " sum(l.l_extendedprice * (1 - l.l_discount)) as revenue,\n" + " o.o_orderdate,\n" + " count(distinct o.o_shippriority)\n" diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 99fbc7d09..2c90f133d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -13,11 +13,11 @@ public class UserTypeFactory { private final InnerType N; private final InnerType R; - private final String uri; + private final String urn; private final String name; - public UserTypeFactory(String uri, String name) { - this.uri = uri; + public UserTypeFactory(String urn, String name) { + this.urn = urn; this.name = name; this.N = new InnerType(true, name); this.R = new InnerType(false, name); @@ -32,7 +32,7 @@ public RelDataType createCalcite(boolean nullable) { } public Type createSubstrait(boolean nullable) { - return TypeCreator.of(nullable).userDefined(uri, name); + return TypeCreator.of(nullable).userDefined(urn, name); } public boolean isTypeFromFactory(RelDataType type) { diff --git a/isthmus/src/test/resources/extensions/functions_custom.yaml b/isthmus/src/test/resources/extensions/functions_custom.yaml index 9fb8b010a..03160f723 100644 --- a/isthmus/src/test/resources/extensions/functions_custom.yaml +++ b/isthmus/src/test/resources/extensions/functions_custom.yaml @@ -1,5 +1,6 @@ %YAML 1.2 --- +urn: extension:substrait:functions_custom types: - name: "a_type" - name: "b_type" diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml index fb33385a1..48281ea03 100644 --- a/spark/src/main/resources/spark.yml +++ b/spark/src/main/resources/spark.yml @@ -14,6 +14,7 @@ # limitations under the License. 
%YAML 1.2 --- +urn: extension:substrait:spark scalar_functions: - name: add description: >- diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index 7bb28d5ed..02d67c4b5 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -26,10 +26,10 @@ import scala.collection.JavaConverters import scala.collection.JavaConverters.asScalaBufferConverter object SparkExtension { - final val uri = "/spark.yml" + final val file = "/spark.yml" private val SparkImpls: SimpleExtension.ExtensionCollection = - SimpleExtension.load(Collections.singletonList(uri)) + SimpleExtension.load(file, getClass.getResourceAsStream(file)) private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = DefaultExtensionCatalog.DEFAULT_COLLECTION; diff --git a/substrait b/substrait index 793c64ba2..4c3531872 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit 793c64ba26e337c22f5e91b658be58b1eea7efd3 +Subproject commit 4c35318727c36d6e49779c06daf9f4ced722fe43