From bc8fc8a47d00fd95ec7b47409bfa9fb8d2cdc268 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 17 Sep 2025 13:23:33 -0400 Subject: [PATCH 1/6] chore: bump substrait version to v0.75.0 --- substrait | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/substrait b/substrait index 793c64ba2..4c3531872 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit 793c64ba26e337c22f5e91b658be58b1eea7efd3 +Subproject commit 4c35318727c36d6e49779c06daf9f4ced722fe43 From 3ee953c67c92e8e648109aba9c04e2f5de36e9ea Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 22 Sep 2025 11:19:59 -0400 Subject: [PATCH 2/6] feat(core): uri -> urn + add urn validation This commit fully moves from URIs to URNs. This is meant as an intermediate commit in the graceful migration. Later work will re-add support for URIs so that both URIs and URNs are accepted on parse and emitted in plans. It also adds the Kotlin dependency needed to make debugging work. BREAKING CHANGE: this commit alters the API by dropping support for URIs and adding support for URNs.
--- .../io/substrait/dsl/SubstraitBuilder.java | 2 +- .../io/substrait/expression/Expression.java | 4 +- .../expression/ExpressionCreator.java | 4 +- .../proto/ExpressionProtoConverter.java | 2 +- .../proto/ProtoExpressionConverter.java | 2 +- .../extension/DefaultExtensionCatalog.java | 30 +-- .../extension/ExtensionCollector.java | 46 ++--- .../extension/ImmutableExtensionLookup.java | 32 ++-- .../substrait/extension/SimpleExtension.java | 177 +++++++++++------- .../java/io/substrait/type/Deserializers.java | 2 +- .../src/main/java/io/substrait/type/Type.java | 2 +- .../java/io/substrait/type/TypeCreator.java | 4 +- .../type/parser/TypeStringParser.java | 16 +- .../type/proto/BaseProtoConverter.java | 2 +- .../type/proto/ProtoTypeConverter.java | 2 +- .../extension/TypeExtensionTest.java | 4 +- .../extension/UrnValidationTest.java | 44 +++++ .../extensions/custom_extensions.yaml | 1 + .../isthmus/expression/CallConverters.java | 2 +- .../isthmus/expression/EnumConverter.java | 2 +- .../io/substrait/isthmus/CalciteTypeTest.java | 2 +- .../substrait/isthmus/CustomFunctionTest.java | 6 +- .../isthmus/RelCopyOnWriteVisitorTest.java | 5 +- .../isthmus/utils/UserTypeFactory.java | 8 +- .../extensions/functions_custom.yaml | 1 + spark/src/main/resources/spark.yml | 1 + .../io/substrait/spark/SparkExtension.scala | 4 +- 27 files changed, 247 insertions(+), 160 deletions(-) create mode 100644 core/src/test/java/io/substrait/extension/UrnValidationTest.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 4e8e428f7..ca6cd4ce3 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -699,7 +699,7 @@ public Expression.WindowFunctionInvocation windowFn( // Types public Type.UserDefined userDefinedType(String namespace, String typeName) { - return 
Type.UserDefined.builder().uri(namespace).name(typeName).nullable(false).build(); + return Type.UserDefined.builder().urn(namespace).name(typeName).nullable(false).build(); } // Misc diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index e5c45e19e..42c3c5118 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -666,13 +666,13 @@ public R accept( abstract class UserDefinedLiteral implements Literal { public abstract ByteString value(); - public abstract String uri(); + public abstract String urn(); public abstract String name(); @Override public Type getType() { - return Type.withNullability(nullable()).userDefined(uri(), name()); + return Type.withNullability(nullable()).userDefined(urn(), name()); } public static ImmutableExpression.UserDefinedLiteral.Builder builder() { diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 4a18026ab..adf157d7b 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -287,10 +287,10 @@ public static Expression.StructLiteral struct( } public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String uri, String name, Any value) { + boolean nullable, String urn, String name, Any value) { return Expression.UserDefinedLiteral.builder() .nullable(nullable) - .uri(uri) + .urn(urn) .name(name) .value(value.toByteString()) .build(); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 093e3fff3..caf145dfc 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ 
b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -361,7 +361,7 @@ public Expression visit( public Expression visit( io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { int typeReference = - extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return lit( bldr -> { try { diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 859ddfca5..8f95cdf07 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -495,7 +495,7 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); + literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 39d28298f..89aad954e 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -5,19 +5,23 @@ import java.util.stream.Collectors; public class DefaultExtensionCatalog { - public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml"; - public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml"; - public 
static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml"; - public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "/functions_arithmetic_decimal.yaml"; - public static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml"; - public static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml"; - public static final String FUNCTIONS_DATETIME = "/functions_datetime.yaml"; - public static final String FUNCTIONS_GEOMETRY = "/functions_geometry.yaml"; - public static final String FUNCTIONS_LOGARITHMIC = "/functions_logarithmic.yaml"; - public static final String FUNCTIONS_ROUNDING = "/functions_rounding.yaml"; - public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml"; - public static final String FUNCTIONS_SET = "/functions_set.yaml"; - public static final String FUNCTIONS_STRING = "/functions_string.yaml"; + public static final String FUNCTIONS_AGGREGATE_APPROX = + "extension:io.substrait:functions_aggregate_approx"; + public static final String FUNCTIONS_AGGREGATE_GENERIC = + "extension:io.substrait:functions_aggregate_generic"; + public static final String FUNCTIONS_ARITHMETIC = "extension:io.substrait:functions_arithmetic"; + public static final String FUNCTIONS_ARITHMETIC_DECIMAL = + "extension:io.substrait:functions_arithmetic_decimal"; + public static final String FUNCTIONS_BOOLEAN = "extension:io.substrait:functions_boolean"; + public static final String FUNCTIONS_COMPARISON = "extension:io.substrait:functions_comparison"; + public static final String FUNCTIONS_DATETIME = "extension:io.substrait:functions_datetime"; + public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry"; + public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic"; + public static final String FUNCTIONS_ROUNDING = "extension:io.substrait:functions_rounding"; + public static final String FUNCTIONS_ROUNDING_DECIMAL = + 
"extension:io.substrait:functions_rounding_decimal"; + public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index d408c600b..1829e2fbb 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -3,7 +3,7 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; -import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; @@ -52,30 +52,30 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensions(simpleExtensions.extensionList); } public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensions(simpleExtensions.extensionList); } private SimpleExtensions getExtensions() { - AtomicInteger uriPos = new AtomicInteger(1); - HashMap uris = new HashMap<>(); + AtomicInteger urnPos = new AtomicInteger(1); + HashMap urns = new HashMap<>(); ArrayList extensionList = new ArrayList<>(); for (Map.Entry e : funcMap.forwardMap.entrySet()) { - 
SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), + SimpleExtensionURN urn = + urns.computeIfAbsent( + e.getValue().urn(), k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(urnPos.getAndIncrement()) + .setUrn(k) .build()); SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() @@ -83,18 +83,18 @@ private SimpleExtensions getExtensions() { SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(e.getKey()) .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) + .setExtensionUrnReference(urn.getExtensionUrnAnchor())) .build(); extensionList.add(decl); } for (Map.Entry e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), + SimpleExtensionURN urn = + urns.computeIfAbsent( + e.getValue().urn(), k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(urnPos.getAndIncrement()) + .setUrn(k) .build()); SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() @@ -102,21 +102,21 @@ private SimpleExtensions getExtensions() { SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(e.getKey()) .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) + .setExtensionUrnReference(urn.getExtensionUrnAnchor())) .build(); extensionList.add(decl); } - return new SimpleExtensions(uris, extensionList); + return new SimpleExtensions(urns, extensionList); } private static final class SimpleExtensions { - final HashMap uris; + final HashMap urns; final ArrayList extensionList; SimpleExtensions( - HashMap uris, + HashMap urns, ArrayList extensionList) { - this.uris = uris; + this.urns = urns; this.extensionList = extensionList; } } diff --git 
a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 6ac4fe922..9bcdeada9 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -3,7 +3,7 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; -import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -30,20 +30,21 @@ public static class Builder { private final Map typeMap = new HashMap<>(); public Builder from(Plan plan) { - return from(plan.getExtensionUrisList(), plan.getExtensionsList()); + return from(plan.getExtensionUrnsList(), plan.getExtensionsList()); } public Builder from(ExtendedExpression extendedExpression) { return from( - extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()); + extendedExpression.getExtensionUrnsList(), extendedExpression.getExtensionsList()); } private Builder from( - List simpleExtensionURIs, + List simpleExtensionURNs, List simpleExtensionDeclarations) { - Map namespaceMap = new HashMap<>(); - for (SimpleExtensionURI extension : simpleExtensionURIs) { - namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); + Map urnMap = new HashMap<>(); + // Handle URN format + for (SimpleExtensionURN extension : simpleExtensionURNs) { + urnMap.put(extension.getExtensionUrnAnchor(), extension.getUrn()); } // Add all functions used in plan to the functionMap @@ -53,13 +54,14 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); int reference = func.getFunctionAnchor(); - String namespace = namespaceMap.get(func.getExtensionUriReference()); - if (namespace == null) { + String urn = 
urnMap.get(func.getExtensionUrnReference()); + if (urn == null) { throw new IllegalStateException( - "Could not find extension URI of " + func.getExtensionUriReference()); + "Could not find extension URN for function reference " + + func.getExtensionUrnReference()); } String name = func.getName(); - SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); + SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(urn, name); functionMap.put(reference, anchor); } @@ -70,13 +72,13 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); int reference = type.getTypeAnchor(); - String namespace = namespaceMap.get(type.getExtensionUriReference()); - if (namespace == null) { + String urn = urnMap.get(type.getExtensionUrnReference()); + if (urn == null) { throw new IllegalStateException( - "Could not find extension URI of " + type.getExtensionUriReference()); + "Could not find extension URN for type reference " + type.getExtensionUrnReference()); } String name = type.getName(); - SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); + SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(urn, name); typeMap.put(reference, anchor); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 14e11ca10..d60049ecc 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -1,10 +1,10 @@ package io.substrait.extension; import com.fasterxml.jackson.annotation.JacksonInject; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonProcessingException; import 
com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.InjectableValues; import com.fasterxml.jackson.databind.ObjectMapper; @@ -26,8 +26,11 @@ import java.util.Map; import java.util.Optional; import java.util.OptionalInt; +import java.util.Scanner; import java.util.Set; +import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -41,12 +44,25 @@ public class SimpleExtension { private static final Logger LOGGER = LoggerFactory.getLogger(SimpleExtension.class); - // Key for looking up URI in InjectableValues - public static final String URI_LOCATOR_KEY = "uri"; + // Key for looking up URN in InjectableValues + public static final String URN_LOCATOR_KEY = "urn"; - private static ObjectMapper objectMapper(String namespace) { + private static final Predicate URN_CHECKER = + Pattern.compile("^extension:[^:]+:[^:]+$").asPredicate(); + + private static void validateUrn(String urn) { + if (urn == null || urn.trim().isEmpty()) { + throw new IllegalArgumentException("URN cannot be null or empty"); + } + if (!URN_CHECKER.test(urn)) { + throw new IllegalArgumentException( + "URN must follow format 'extension::', got: " + urn); + } + } + + private static ObjectMapper objectMapper(String urn) { InjectableValues.Std iv = new InjectableValues.Std(); - iv.addValue(URI_LOCATOR_KEY, namespace); + iv.addValue(URN_LOCATOR_KEY, urn); return new ObjectMapper(new YAMLFactory()) .enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @@ -184,25 +200,22 @@ public static ImmutableSimpleExtension.EnumArgument.Builder builder() { } public interface Anchor { - String namespace(); + String urn(); String key(); } @Value.Immutable public interface FunctionAnchor extends Anchor { - static FunctionAnchor of(String namespace, String key) { - return ImmutableSimpleExtension.FunctionAnchor.builder() - 
.namespace(namespace) - .key(key) - .build(); + static FunctionAnchor of(String urn, String key) { + return ImmutableSimpleExtension.FunctionAnchor.builder().urn(urn).key(key).build(); } } @Value.Immutable public interface TypeAnchor extends Anchor { - static TypeAnchor of(String namespace, String name) { - return ImmutableSimpleExtension.TypeAnchor.builder().namespace(namespace).key(name).build(); + static TypeAnchor of(String urn, String name) { + return ImmutableSimpleExtension.TypeAnchor.builder().urn(urn).key(name).build(); } } @@ -226,7 +239,7 @@ default ParameterConsistency parameterConsistency() { public abstract static class Function { private final Supplier anchorSupplier = - Util.memoize(() -> FunctionAnchor.of(uri(), key())); + Util.memoize(() -> FunctionAnchor.of(urn(), key())); private final Supplier keySupplier = Util.memoize(() -> constructKey(name(), args())); private final Supplier> requiredArgsSupplier = Util.memoize( @@ -241,8 +254,8 @@ public String name() { } @Value.Default - public String uri() { - // we can't use null detection here since we initially construct this without a uri, then + public String urn() { + // we can't use null detection here since we initially construct this without a urn, then // resolve later. 
return ""; } @@ -366,8 +379,8 @@ public abstract static class ScalarFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -375,9 +388,9 @@ public Stream resolve(String uri) { @JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class) @Value.Immutable public abstract static class ScalarFunctionVariant extends Function { - public ScalarFunctionVariant resolve(String uri, String name, String description) { + public ScalarFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -402,8 +415,8 @@ public abstract static class AggregateFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -419,8 +432,8 @@ public abstract static class WindowFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } public static ImmutableSimpleExtension.WindowFunction.Builder builder() { @@ -446,9 +459,9 @@ public String toString() { @Nullable public abstract TypeExpression intermediate(); - AggregateFunctionVariant resolve(String uri, String name, String description) { + AggregateFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.AggregateFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) 
.description(description) .nullability(nullability()) @@ -488,9 +501,9 @@ public String toString() { return super.toString(); } - WindowFunctionVariant resolve(String uri, String name, String description) { + WindowFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.WindowFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -515,12 +528,12 @@ public static ImmutableSimpleExtension.WindowFunctionVariant.Builder builder() { @Value.Immutable public abstract static class Type { private final Supplier anchorSupplier = - Util.memoize(() -> TypeAnchor.of(uri(), name())); + Util.memoize(() -> TypeAnchor.of(urn(), name())); public abstract String name(); - @JacksonInject(SimpleExtension.URI_LOCATOR_KEY) - public abstract String uri(); + @JacksonInject(SimpleExtension.URN_LOCATOR_KEY) + public abstract String urn(); // TODO: Handle conversion of structure object to Named Struct representation protected abstract Optional structure(); @@ -532,11 +545,15 @@ public TypeAnchor getAnchor() { @JsonDeserialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) @JsonSerialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) + @JsonIgnoreProperties(ignoreUnknown = true) @Value.Immutable public abstract static class ExtensionSignatures { @JsonProperty("types") public abstract List types(); + @JsonProperty("urn") + public abstract String urn(); + @JsonProperty("scalar_functions") public abstract List scalars(); @@ -553,27 +570,27 @@ public int size() { + (windows() == null ? 0 : windows().size()); } - public Stream resolve(String uri) { + public Stream resolve(String urn) { return Stream.concat( Stream.concat( - scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(uri)), + scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(urn)), aggregates() == null ? 
Stream.of() - : aggregates().stream().flatMap(f -> f.resolve(uri))), - windows() == null ? Stream.of() : windows().stream().flatMap(f -> f.resolve(uri))); + : aggregates().stream().flatMap(f -> f.resolve(urn))), + windows() == null ? Stream.of() : windows().stream().flatMap(f -> f.resolve(urn))); } } @Value.Immutable public abstract static class ExtensionCollection { - private final Supplier> namespaceSupplier = + private final Supplier> urnSupplier = Util.memoize( () -> { return Stream.concat( Stream.concat( - scalarFunctions().stream().map(Function::uri), - aggregateFunctions().stream().map(Function::uri)), - windowFunctions().stream().map(Function::uri)) + scalarFunctions().stream().map(Function::urn), + aggregateFunctions().stream().map(Function::urn)), + windowFunctions().stream().map(Function::urn)) .collect(Collectors.toSet()); }); @@ -627,11 +644,11 @@ public Type getType(TypeAnchor anchor) { if (type != null) { return type; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected type with name %s. The namespace %s is loaded but no type with this name found.", - anchor.key(), anchor.namespace())); + "Unexpected type with name %s. The URN %s is loaded but no type with this name found.", + anchor.key(), anchor.urn())); } public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { @@ -639,16 +656,16 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { if (variant != null) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected scalar function with key %s. The namespace %s is loaded " + "Unexpected scalar function with key %s. 
The URN %s is loaded " + "but no scalar function with this key found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } - private void checkNamespace(String name) { - if (namespaceSupplier.get().contains(name)) { + private void checkUrn(String name) { + if (urnSupplier.get().contains(name)) { return; } @@ -665,12 +682,12 @@ public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected aggregate function with key %s. The namespace %s is loaded " + "Unexpected aggregate function with key %s. The URN %s is loaded " + "but no aggregate function with this key was found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { @@ -678,12 +695,12 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { if (variant != null) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected window aggregate function with key %s. The namespace %s is loaded " + "Unexpected window aggregate function with key %s. 
The URN %s is loaded " + "but no window aggregate function with this key was found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } public ExtensionCollection merge(ExtensionCollection extensionCollection) { @@ -710,7 +727,7 @@ public static ExtensionCollection load(List resourcePaths) { .map( path -> { try (InputStream stream = ExtensionCollection.class.getResourceAsStream(path)) { - return load(path, stream); + return load(stream); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -723,41 +740,51 @@ public static ExtensionCollection load(List resourcePaths) { return complete; } - public static ExtensionCollection load(String namespace, String str) { + public static ExtensionCollection load(String content) { try { - ExtensionSignatures doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); - return buildExtensionCollection(namespace, doc); - } catch (JsonProcessingException e) { + // Parse with basic YAML mapper first to extract URN (if present) + ObjectMapper basicYamlMapper = new ObjectMapper(new YAMLFactory()); + com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content); + + // URN is required + com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn"); + if (urnNode == null) { + throw new IllegalArgumentException("Extension YAML file must contain a 'urn' field"); + } + String urn = urnNode.asText(); + validateUrn(urn); + + // Then parse with URN-aware mapper + ExtensionSignatures doc = objectMapper(urn).readValue(content, ExtensionSignatures.class); + return buildExtensionCollection(urn, doc); + } catch (IOException e) { throw new IllegalStateException(e); } } - public static ExtensionCollection load(String namespace, InputStream stream) { - try { - ExtensionSignatures doc = - objectMapper(namespace).readValue(stream, ExtensionSignatures.class); - return buildExtensionCollection(namespace, doc); - } catch (RuntimeException ex) { - throw ex; - } catch (Exception 
ex) { - throw new IllegalStateException("Failure while parsing " + namespace, ex); + public static ExtensionCollection load(InputStream stream) { + try (Scanner scanner = new Scanner(stream)) { + scanner.useDelimiter("\\A"); + String content = scanner.next(); + return load(content); } } public static ExtensionCollection buildExtensionCollection( - String namespace, ExtensionSignatures extensionSignatures) { + String urn, ExtensionSignatures extensionSignatures) { + validateUrn(urn); List scalarFunctionVariants = extensionSignatures.scalars().stream() - .flatMap(t -> t.resolve(namespace)) + .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); List aggregateFunctionVariants = extensionSignatures.aggregates().stream() - .flatMap(t -> t.resolve(namespace)) + .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); Stream windowFunctionVariants = - extensionSignatures.windows().stream().flatMap(t -> t.resolve(namespace)); + extensionSignatures.windows().stream().flatMap(t -> t.resolve(urn)); // Aggregate functions can be used as Window Functions Stream windowAggFunctionVariants = @@ -789,7 +816,13 @@ public static ExtensionCollection buildExtensionCollection( "Loaded {} aggregate functions and {} scalar functions from {}.", collection.aggregateFunctions().size(), collection.scalarFunctions().size(), - namespace); + extensionSignatures.urn()); return collection; } + + public static ExtensionCollection buildExtensionCollection( + ExtensionSignatures extensionSignatures) { + String urn = extensionSignatures.urn(); + return buildExtensionCollection(urn, extensionSignatures); + } } diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 160004a3b..efdbc3306 100644 --- a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -44,7 +44,7 @@ public T deserialize(final JsonParser p, final DeserializationContext ctxt) String 
typeString = p.getValueAsString(); try { String namespace = - (String) ctxt.findInjectableValue(SimpleExtension.URI_LOCATOR_KEY, null, null); + (String) ctxt.findInjectableValue(SimpleExtension.URN_LOCATOR_KEY, null, null); return TypeStringParser.parse(typeString, namespace, converter); } catch (Exception ex) { throw JsonMappingException.from( diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 362adc3a8..aaf97aa12 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -389,7 +389,7 @@ public R accept(final TypeVisitor typeVisitor) th @Value.Immutable abstract class UserDefined implements Type { - public abstract String uri(); + public abstract String urn(); public abstract String name(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 880b72ed9..43358e505 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -108,8 +108,8 @@ public Type.Map map(Type key, Type value) { return Type.Map.builder().nullable(nullable).key(key).value(value).build(); } - public Type userDefined(String uri, String name) { - return Type.UserDefined.builder().nullable(nullable).uri(uri).name(name).build(); + public Type userDefined(String urn, String name) { + return Type.UserDefined.builder().nullable(nullable).urn(urn).name(name).build(); } public static TypeCreator of(boolean nullability) { diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index 07093e925..8085ea6b8 100644 --- a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -16,16 +16,16 @@ public class TypeStringParser { private TypeStringParser() {} - public static Type 
parseSimple(String str, String namespace) { - return parse(str, namespace, ParseToPojo::type); + public static Type parseSimple(String str, String urn) { + return parse(str, urn, ParseToPojo::type); } - public static ParameterizedType parseParameterized(String str, String namespace) { - return parse(str, namespace, ParseToPojo::parameterizedType); + public static ParameterizedType parseParameterized(String str, String urn) { + return parse(str, urn, ParseToPojo::parameterizedType); } - public static TypeExpression parseExpression(String str, String namespace) { - return parse(str, namespace, ParseToPojo::typeExpression); + public static TypeExpression parseExpression(String str, String urn) { + return parse(str, urn, ParseToPojo::typeExpression); } private static SubstraitTypeParser.StartContext parse(String str) { @@ -40,8 +40,8 @@ private static SubstraitTypeParser.StartContext parse(String str) { } public static T parse( - String str, String namespace, BiFunction func) { - return func.apply(namespace, parse(str)); + String str, String urn, BiFunction func) { + return func.apply(urn, parse(str)); } public static TypeExpression parse(String str, ParseToPojo.Visitor visitor) { diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index a8c64db95..691d4bce5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -164,7 +164,7 @@ public final T visit(final Type.Map expr) { @Override public final T visit(final Type.UserDefined expr) { int ref = - extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return typeContainer(expr).userDefined(ref); } } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java 
b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 661f57fea..95d42328a 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -90,7 +90,7 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.uri(), t.name()); + return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); } case USER_DEFINED_TYPE_REFERENCE: throw new UnsupportedOperationException("Unsupported user defined reference: " + type); diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index cd2522090..e3bd9b7f4 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -28,13 +28,13 @@ public class TypeExtensionTest { static final TypeCreator R = TypeCreator.of(false); - static final String NAMESPACE = "/custom_extensions"; + static final String NAMESPACE = "extension:test:custom_extensions"; final SimpleExtension.ExtensionCollection extensionCollection; { InputStream inputStream = this.getClass().getResourceAsStream("/extensions/custom_extensions.yaml"); - extensionCollection = SimpleExtension.load(NAMESPACE, inputStream); + extensionCollection = SimpleExtension.load(inputStream); } final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); diff --git a/core/src/test/java/io/substrait/extension/UrnValidationTest.java b/core/src/test/java/io/substrait/extension/UrnValidationTest.java new file mode 100644 index 000000000..3d3354a6f --- /dev/null +++ b/core/src/test/java/io/substrait/extension/UrnValidationTest.java @@ -0,0 +1,44 @@ +package io.substrait.extension; + 
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class UrnValidationTest { + + @Test + public void testMissingUrnThrowsException() { + String yamlWithoutUrn = "%YAML 1.2\n" + "---\n" + "scalar_functions:\n" + " - name: test\n"; + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load(yamlWithoutUrn)); + assertTrue(exception.getMessage().contains("Extension YAML file must contain a 'urn' field")); + } + + @Test + public void testInvalidUrnFormatThrowsException() { + String yamlWithInvalidUrn = + "%YAML 1.2\n" + + "---\n" + + "urn: invalid:format\n" + + "scalar_functions:\n" + + " - name: test\n"; + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> SimpleExtension.load(yamlWithInvalidUrn)); + assertTrue( + exception.getMessage().contains("URN must follow format 'extension::'")); + } + + @Test + public void testValidUrnWorks() { + String yamlWithValidUrn = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:valid\n" + + "scalar_functions:\n" + + " - name: test\n"; + assertDoesNotThrow(() -> SimpleExtension.load(yamlWithValidUrn)); + } +} diff --git a/core/src/test/resources/extensions/custom_extensions.yaml b/core/src/test/resources/extensions/custom_extensions.yaml index 204a5f9ac..4776312ac 100644 --- a/core/src/test/resources/extensions/custom_extensions.yaml +++ b/core/src/test/resources/extensions/custom_extensions.yaml @@ -1,5 +1,6 @@ %YAML 1.2 --- +urn: extension:test:custom_extensions types: - name: "customType1" - name: "customType2" diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index c0e1796da..3406de7de 100644 --- 
a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -71,7 +71,7 @@ public class CallConverters { Type.UserDefined t = (Type.UserDefined) type; return Expression.UserDefinedLiteral.builder() - .uri(t.uri()) + .urn(t.urn()) .name(t.name()) .value(literal.value()) .build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index cec68fcb7..b69ef9b02 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -185,7 +185,7 @@ private static ArgAnchor argAnchor(String fnNS, String fnSig, int argIdx) { private static ArgAnchor argAnchor(SimpleExtension.Function fnDef, int argIdx) { return new ArgAnchor( - SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().namespace(), fnDef.getAnchor().key()), + SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().urn(), fnDef.getAnchor().key()), argIdx); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java index a12a1eac6..176552f7d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java @@ -37,7 +37,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.uri().equals(uTypeURI) && type.name().equals(uTypeName)) { + if (type.urn().equals(uTypeURI) && type.name().equals(uTypeName)) { return uTypeFactory.createCalcite(type.nullable()); } return null; diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index fabddf56e..7e59a82a2 100644 --- 
a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -43,7 +43,7 @@ public class CustomFunctionTest extends PlanTestBase { // Define custom functions in a "functions_custom.yaml" extension - static final String NAMESPACE = "/functions_custom"; + static final String NAMESPACE = "extension:substrait:functions_custom"; static final String FUNCTIONS_CUSTOM; static { @@ -56,7 +56,7 @@ public class CustomFunctionTest extends PlanTestBase { // Load custom extension into an ExtensionCollection static final SimpleExtension.ExtensionCollection extensionCollection = - SimpleExtension.load("/functions_custom", FUNCTIONS_CUSTOM); + SimpleExtension.load(FUNCTIONS_CUSTOM); final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); @@ -84,7 +84,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.uri().equals(NAMESPACE)) { + if (type.urn().equals(NAMESPACE)) { if (type.name().equals(aTypeName)) { return aTypeFactory.createCalcite(type.nullable()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index ceb01e9e3..f6ca3b150 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -29,9 +29,10 @@ public class RelCopyOnWriteVisitorTest extends PlanTestBase { public static SimpleExtension.FunctionAnchor APPROX_COUNT_DISTINCT = SimpleExtension.FunctionAnchor.of( - "/functions_aggregate_approx.yaml", "approx_count_distinct:any"); + "extension:io.substrait:functions_aggregate_approx", "approx_count_distinct:any"); public static SimpleExtension.FunctionAnchor COUNT = - SimpleExtension.FunctionAnchor.of("/functions_aggregate_generic.yaml", "count:any"); + 
SimpleExtension.FunctionAnchor.of( + "extension:io.substrait:functions_aggregate_generic", "count:any"); private static final String COUNT_DISTINCT_SUBBQUERY = "select\n" diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 99fbc7d09..2c90f133d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -13,11 +13,11 @@ public class UserTypeFactory { private final InnerType N; private final InnerType R; - private final String uri; + private final String urn; private final String name; - public UserTypeFactory(String uri, String name) { - this.uri = uri; + public UserTypeFactory(String urn, String name) { + this.urn = urn; this.name = name; this.N = new InnerType(true, name); this.R = new InnerType(false, name); @@ -32,7 +32,7 @@ public RelDataType createCalcite(boolean nullable) { } public Type createSubstrait(boolean nullable) { - return TypeCreator.of(nullable).userDefined(uri, name); + return TypeCreator.of(nullable).userDefined(urn, name); } public boolean isTypeFromFactory(RelDataType type) { diff --git a/isthmus/src/test/resources/extensions/functions_custom.yaml b/isthmus/src/test/resources/extensions/functions_custom.yaml index 9fb8b010a..03160f723 100644 --- a/isthmus/src/test/resources/extensions/functions_custom.yaml +++ b/isthmus/src/test/resources/extensions/functions_custom.yaml @@ -1,5 +1,6 @@ %YAML 1.2 --- +urn: extension:substrait:functions_custom types: - name: "a_type" - name: "b_type" diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml index fb33385a1..48281ea03 100644 --- a/spark/src/main/resources/spark.yml +++ b/spark/src/main/resources/spark.yml @@ -14,6 +14,7 @@ # limitations under the License. 
%YAML 1.2 --- +urn: extension:substrait:spark scalar_functions: - name: add description: >- diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index 7bb28d5ed..595d5c169 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -26,10 +26,10 @@ import scala.collection.JavaConverters import scala.collection.JavaConverters.asScalaBufferConverter object SparkExtension { - final val uri = "/spark.yml" + final val file = "/spark.yml" private val SparkImpls: SimpleExtension.ExtensionCollection = - SimpleExtension.load(Collections.singletonList(uri)) + SimpleExtension.load(getClass.getResourceAsStream(file)) private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = DefaultExtensionCatalog.DEFAULT_COLLECTION; From c7639b3743e1e42d10239fbeaf17dbfc469576f0 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 24 Sep 2025 16:01:14 -0400 Subject: [PATCH 3/6] refactor: move BidiMap to another file to be shared Also introduces a put method to ensure no duplicates, as well as a merge method. --- .../java/io/substrait/extension/BidiMap.java | 65 +++++++++++++++++++ .../extension/ExtensionCollector.java | 28 +------- 2 files changed, 67 insertions(+), 26 deletions(-) create mode 100644 core/src/main/java/io/substrait/extension/BidiMap.java diff --git a/core/src/main/java/io/substrait/extension/BidiMap.java b/core/src/main/java/io/substrait/extension/BidiMap.java new file mode 100644 index 000000000..b9ec22135 --- /dev/null +++ b/core/src/main/java/io/substrait/extension/BidiMap.java @@ -0,0 +1,65 @@ +package io.substrait.extension; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** We don't depend on guava... 
*/ +class BidiMap { + private final Map forwardMap; + private final Map reverseMap; + + BidiMap(Map forwardMap) { + this.forwardMap = forwardMap; + this.reverseMap = new HashMap<>(); + for (Map.Entry entry : forwardMap.entrySet()) { + reverseMap.put(entry.getValue(), entry.getKey()); + } + } + + BidiMap() { + this.forwardMap = new HashMap<>(); + this.reverseMap = new HashMap<>(); + } + + T2 get(T1 t1) { + return forwardMap.get(t1); + } + + T1 reverseGet(T2 t2) { + return reverseMap.get(t2); + } + + /** + * Associates the specified values in both directions. Throws if either value is already mapped to + * a different value. + */ + void put(T1 t1, T2 t2) { + T2 existingForward = forwardMap.get(t1); + T1 existingReverse = reverseMap.get(t2); + + if (existingForward != null && !existingForward.equals(t2)) { + throw new IllegalArgumentException("Key already exists in map with different value"); + } + if (existingReverse != null && !existingReverse.equals(t1)) { + throw new IllegalArgumentException("Key already exists in map with different value"); + } + + forwardMap.put(t1, t2); + reverseMap.put(t2, t1); + } + + void merge(BidiMap other) { + for (Map.Entry entry : other.forwardEntrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + Set> forwardEntrySet() { + return forwardMap.entrySet(); + } + + Set> reverseEntrySet() { + return reverseMap.entrySet(); + } +} diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 1829e2fbb..b871a7bc0 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -68,7 +68,7 @@ private SimpleExtensions getExtensions() { HashMap urns = new HashMap<>(); ArrayList extensionList = new ArrayList<>(); - for (Map.Entry e : funcMap.forwardMap.entrySet()) { + for (Map.Entry e : funcMap.forwardEntrySet()) { SimpleExtensionURN urn = 
urns.computeIfAbsent( e.getValue().urn(), @@ -87,7 +87,7 @@ private SimpleExtensions getExtensions() { .build(); extensionList.add(decl); } - for (Map.Entry e : typeMap.forwardMap.entrySet()) { + for (Map.Entry e : typeMap.forwardEntrySet()) { SimpleExtensionURN urn = urns.computeIfAbsent( e.getValue().urn(), @@ -120,28 +120,4 @@ private static final class SimpleExtensions { this.extensionList = extensionList; } } - - /** We don't depend on guava... */ - private static class BidiMap { - private final Map forwardMap; - private final Map reverseMap; - - public BidiMap(Map forwardMap) { - this.forwardMap = forwardMap; - this.reverseMap = new HashMap<>(); - } - - public T2 get(T1 t1) { - return forwardMap.get(t1); - } - - public T1 reverseGet(T2 t2) { - return reverseMap.get(t2); - } - - public void put(T1 t1, T2 t2) { - forwardMap.put(t1, t2); - reverseMap.put(t2, t1); - } - } } From 7fd825a1156ad01cd57b4f3b863a9f68199159fe Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 25 Sep 2025 12:09:26 -0400 Subject: [PATCH 4/6] feat(core): graceful uri <-> urn handling Implement comprehensive fallback strategy to resolve extensions from either URN or legacy URI references with conflict detection for protobuf ambiguity. Introduce round-trip tests to ensure output always has uri + urn regardless of combination of uri or urn in input plans. BREAKING CHANGE: This PR alters the extension loading API to require both URI explicitly, and URN implicitly in the plan. Extensions which lack URN will throw an error, and parsed plans which contain a uri/urn without a matching urn/uri pre-loaded into the context will throw an error. 
--- core/build.gradle.kts | 2 + .../ProtoExtendedExpressionConverter.java | 5 +- .../extension/AbstractExtensionLookup.java | 20 + .../java/io/substrait/extension/BidiMap.java | 2 +- .../extension/ExtensionCollector.java | 106 ++- .../extension/ImmutableExtensionLookup.java | 176 ++++- .../substrait/extension/SimpleExtension.java | 83 ++- .../io/substrait/plan/PlanProtoConverter.java | 16 +- .../io/substrait/plan/ProtoPlanConverter.java | 6 +- .../ExtensionCollectionMergeTest.java | 87 +++ .../ExtensionCollectionUriUrnTest.java | 60 ++ .../ExtensionCollectorUriUrnTest.java | 41 ++ .../ImmutableExtensionLookupUriUrnTest.java | 621 ++++++++++++++++++ .../extension/TypeExtensionTest.java | 6 +- .../UriUrnMigrationEndToEndTest.java | 110 ++++ .../extension/UrnValidationTest.java | 22 +- .../complex-expected-plan.json | 162 +++++ .../uri-urn-migration/complex-input-plan.json | 148 +++++ .../mixed-partial-coverage-expected-plan.json | 133 ++++ .../mixed-partial-coverage-input-plan.json | 110 ++++ .../unresolvable-uri-plan.json | 73 ++ .../uri-only-expected-plan.json | 98 +++ .../uri-only-input-plan.json | 80 +++ .../urn-only-expected-plan.json | 133 ++++ .../urn-only-input-plan.json | 110 ++++ .../zero-urn-resolution-expected-plan.json | 104 +++ .../zero-urn-resolution-input-plan.json | 85 +++ .../substrait/isthmus/CustomFunctionTest.java | 2 +- .../isthmus/RelCopyOnWriteVisitorTest.java | 2 +- .../io/substrait/spark/SparkExtension.scala | 2 +- 30 files changed, 2539 insertions(+), 66 deletions(-) create mode 100644 core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java create mode 100644 core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java create mode 100644 core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java create mode 100644 core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java create mode 100644 core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java 
create mode 100644 core/src/test/resources/uri-urn-migration/complex-expected-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/complex-input-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/uri-only-input-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/urn-only-input-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json create mode 100644 core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 09706f115..97b9dfbf6 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -108,6 +108,8 @@ configurations[JavaPlugin.TEST_IMPLEMENTATION_CONFIGURATION_NAME].extendsFrom(sh dependencies { testImplementation(platform(libs.junit.bom)) + testImplementation(libs.protobuf.java.util) + testImplementation(libs.junit.jupiter) testRuntimeOnly(libs.junit.platform.launcher) diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index c658fcbce..175b2d705 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -28,6 +28,9 @@ public ProtoExtendedExpressionConverter() { } public 
ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } this.extensionCollection = extensionCollection; } @@ -35,7 +38,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp // fill in simple extension information through a discovery in the current proto-extended // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder().from(extendedExpression).build(); + ImmutableExtensionLookup.builder(extensionCollection).from(extendedExpression).build(); NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); diff --git a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java index 16e41f03f..4693a2674 100644 --- a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java @@ -13,6 +13,26 @@ public AbstractExtensionLookup( this.typeAnchorMap = typeAnchorMap; } + /** + * Gets the function anchor for a given reference (primarily for testing). + * + * @param reference The function reference + * @return The function anchor, or null if not found + */ + public SimpleExtension.FunctionAnchor getFunctionAnchor(int reference) { + return functionAnchorMap.get(reference); + } + + /** + * Gets the type anchor for a given reference (primarily for testing). 
+ * + * @param reference The type reference + * @return The type anchor, or null if not found + */ + public SimpleExtension.TypeAnchor getTypeAnchor(int reference) { + return typeAnchorMap.get(reference); + } + @Override public SimpleExtension.ScalarFunctionVariant getScalarFunction( int reference, SimpleExtension.ExtensionCollection extensions) { diff --git a/core/src/main/java/io/substrait/extension/BidiMap.java b/core/src/main/java/io/substrait/extension/BidiMap.java index b9ec22135..f0eeb30a7 100644 --- a/core/src/main/java/io/substrait/extension/BidiMap.java +++ b/core/src/main/java/io/substrait/extension/BidiMap.java @@ -5,7 +5,7 @@ import java.util.Set; /** We don't depend on guava... */ -class BidiMap { +public class BidiMap { private final Map forwardMap; private final Map reverseMap; diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index b871a7bc0..ec313d860 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -3,6 +3,7 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import io.substrait.proto.SimpleExtensionURN; import java.util.ArrayList; import java.util.HashMap; @@ -19,14 +20,27 @@ public class ExtensionCollector extends AbstractExtensionLookup { private final BidiMap funcMap; private final BidiMap typeMap; + private final SimpleExtension.ExtensionCollection extensionCollection; // start at 0 to make sure functionAnchors start with 1 according to spec private int counter = 0; + private String getUriFromUrn(String urn) { + return extensionCollection.getUriFromUrn(urn); + } + public ExtensionCollector() { + this(SimpleExtension.loadDefaults()); + } + + public ExtensionCollector(SimpleExtension.ExtensionCollection 
extensionCollection) { super(new HashMap<>(), new HashMap<>()); + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } funcMap = new BidiMap<>(functionAnchorMap); typeMap = new BidiMap<>(typeAnchorMap); + this.extensionCollection = extensionCollection; } public int getFunctionReference(SimpleExtension.Function declaration) { @@ -53,6 +67,7 @@ public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); builder.addAllExtensionUrns(simpleExtensions.urns.values()); + builder.addAllExtensionUris(simpleExtensions.uris.values()); builder.addAllExtensions(simpleExtensions.extensionList); } @@ -60,63 +75,116 @@ public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder SimpleExtensions simpleExtensions = getExtensions(); builder.addAllExtensionUrns(simpleExtensions.urns.values()); + builder.addAllExtensionUris(simpleExtensions.uris.values()); builder.addAllExtensions(simpleExtensions.extensionList); } private SimpleExtensions getExtensions() { AtomicInteger urnPos = new AtomicInteger(1); + AtomicInteger uriPos = new AtomicInteger(1); HashMap urns = new HashMap<>(); + HashMap uris = new HashMap<>(); ArrayList extensionList = new ArrayList<>(); for (Map.Entry e : funcMap.forwardEntrySet()) { - SimpleExtensionURN urn = + String urn = e.getValue().urn(); + String uri = getUriFromUrn(urn); + + // Create URN entry + SimpleExtensionURN urnObj = urns.computeIfAbsent( - e.getValue().urn(), + urn, k -> SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(urnPos.getAndIncrement()) .setUrn(k) .build()); + + // Create URI entry if mapping exists + SimpleExtensionURI uriObj = null; + if (uri != null) { + uriObj = + uris.computeIfAbsent( + uri, + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + } + + // Create function declaration with both URN and URI references + 
SimpleExtensionDeclaration.ExtensionFunction.Builder funcBuilder = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUrnReference(urnObj.getExtensionUrnAnchor()); + + if (uriObj != null) { + funcBuilder.setExtensionUriReference(uriObj.getExtensionUriAnchor()); + } + SimpleExtensionDeclaration decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUrnReference(urn.getExtensionUrnAnchor())) - .build(); + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(funcBuilder).build(); extensionList.add(decl); } + for (Map.Entry e : typeMap.forwardEntrySet()) { - SimpleExtensionURN urn = + String urn = e.getValue().urn(); + String uri = getUriFromUrn(urn); + + // Create URN entry + SimpleExtensionURN urnObj = urns.computeIfAbsent( - e.getValue().urn(), + urn, k -> SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(urnPos.getAndIncrement()) .setUrn(k) .build()); + + // Create URI entry if mapping exists + SimpleExtensionURI uriObj = null; + if (uri != null) { + uriObj = + uris.computeIfAbsent( + uri, + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + } + + // Create type declaration with both URN and URI references + SimpleExtensionDeclaration.ExtensionType.Builder typeBuilder = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUrnReference(urnObj.getExtensionUrnAnchor()); + + if (uriObj != null) { + typeBuilder.setExtensionUriReference(uriObj.getExtensionUriAnchor()); + } + SimpleExtensionDeclaration decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionType( - SimpleExtensionDeclaration.ExtensionType.newBuilder() - .setTypeAnchor(e.getKey()) - 
.setName(e.getValue().key()) - .setExtensionUrnReference(urn.getExtensionUrnAnchor())) - .build(); + SimpleExtensionDeclaration.newBuilder().setExtensionType(typeBuilder).build(); extensionList.add(decl); } - return new SimpleExtensions(urns, extensionList); + return new SimpleExtensions(urns, uris, extensionList); } private static final class SimpleExtensions { final HashMap urns; + final HashMap uris; final ArrayList extensionList; SimpleExtensions( HashMap urns, + HashMap uris, ArrayList extensionList) { this.urns = urns; + this.uris = uris; this.extensionList = extensionList; } } diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 9bcdeada9..a8a365dbe 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -25,28 +25,183 @@ public static Builder builder() { return new Builder(); } + public static Builder builder(SimpleExtension.ExtensionCollection extensionCollection) { + return new Builder(extensionCollection); + } + public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); + private final SimpleExtension.ExtensionCollection extensionCollection; + + public Builder() { + this.extensionCollection = SimpleExtension.loadDefaults(); + } + + public Builder(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } + this.extensionCollection = extensionCollection; + } + + /** + * Resolves URN from URI using the URI/URN mapping. 
+ * + * @param uri The URI to resolve + * @return The corresponding URN, or null if no mapping exists + */ + private String resolveUrnFromUri(String uri) { + return extensionCollection.getUrnFromUri(uri); + } + + private SimpleExtension.FunctionAnchor resolveFunctionAnchor( + SimpleExtensionDeclaration.ExtensionFunction func, + Map urnMap, + Map uriMap) { + + // 1. Try non-zero URN reference + if (func.getExtensionUrnReference() != 0) { + String urnFromUrnRef = urnMap.get(func.getExtensionUrnReference()); + if (urnFromUrnRef != null) { + return SimpleExtension.FunctionAnchor.of(urnFromUrnRef, func.getName()); + } + } + + // 2. Try non-zero URI reference + if (func.getExtensionUriReference() != 0) { + String uriFromUriRef = uriMap.get(func.getExtensionUriReference()); + if (uriFromUriRef != null) { + String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + if (urnFromUriRef != null) { + return SimpleExtension.FunctionAnchor.of(urnFromUriRef, func.getName()); + } + // URI found but could not be resolved to URN + } + } + + /* At this point both URI and URN are 0, so we need to + first see if they both resolve. + */ + + String urn = urnMap.get(func.getExtensionUrnReference()); + String uri = uriMap.get(func.getExtensionUriReference()); + + // 3. Try both 0 URI and 0 URN if both resolve + if (uri != null && urn != null) { + String resolvedUrn = resolveUrnFromUri(uri); + if (urn.equals(resolvedUrn)) { + return SimpleExtension.FunctionAnchor.of(urn, func.getName()); + } + throw new IllegalStateException( + String.format( + "Conflicting URI/URN mapping at reference 0: URI '%s' maps to URN '%s', but reference 0 also specifies URN '%s'. " + + "These must be consistent for proper resolution.", + uri, resolvedUrn, urn)); + } + + // 4. Try only 0 URN + if (urn != null) { + return SimpleExtension.FunctionAnchor.of(urn, func.getName()); + } + // 5. 
Try only 0 URI + if (uri != null && resolveUrnFromUri(uri) != null) { + return SimpleExtension.FunctionAnchor.of(resolveUrnFromUri(uri), func.getName()); + } + throw new IllegalStateException( + String.format( + "All resolution strategies failed for URI %s and URN %s (perhaps a URI <-> URN mapping was not registered during the migration) ", + uri, urn)); + } + + private SimpleExtension.TypeAnchor resolveTypeAnchor( + SimpleExtensionDeclaration.ExtensionType type, + Map urnMap, + Map uriMap) { + + // 1. Try non-zero URN reference + if (type.getExtensionUrnReference() != 0) { + String urnFromUrnRef = urnMap.get(type.getExtensionUrnReference()); + if (urnFromUrnRef != null) { + return SimpleExtension.TypeAnchor.of(urnFromUrnRef, type.getName()); + } + } + + // 2. Try non-zero URI reference + if (type.getExtensionUriReference() != 0) { + String uriFromUriRef = uriMap.get(type.getExtensionUriReference()); + if (uriFromUriRef != null) { + String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + if (urnFromUriRef != null) { + return SimpleExtension.TypeAnchor.of(urnFromUriRef, type.getName()); + } + // URI found but could not be resolved to URN + } + } + + /* Reached when no non-zero reference yielded a URN (it was 0 or did not resolve). + The same reference values are looked up again; check first whether both resolve. + */ + + String urn = urnMap.get(type.getExtensionUrnReference()); + String uri = uriMap.get(type.getExtensionUriReference()); + + // 3. Try both 0 URI and 0 URN if both resolve + if (uri != null && urn != null) { + String resolvedUrn = resolveUrnFromUri(uri); + if (urn.equals(resolvedUrn)) { + return SimpleExtension.TypeAnchor.of(urn, type.getName()); + } + throw new IllegalStateException( + String.format( + "Conflicting URI/URN mapping at reference 0: URI '%s' maps to URN '%s', but reference 0 also specifies URN '%s'. " + + "These must be consistent for proper resolution.", + uri, resolvedUrn, urn)); + } + + // 4. Try only 0 URN + if (urn != null) { + return SimpleExtension.TypeAnchor.of(urn, type.getName()); + } + // 5.
Try only 0 URI + if (uri != null && resolveUrnFromUri(uri) != null) { + return SimpleExtension.TypeAnchor.of(resolveUrnFromUri(uri), type.getName()); + } + throw new IllegalStateException( + String.format( + "All resolution strategies failed for URI %s and URN %s (perhaps a URI <-> URN mapping was not registered during the migration) ", + uri, urn)); + } public Builder from(Plan plan) { - return from(plan.getExtensionUrnsList(), plan.getExtensionsList()); + return from( + plan.getExtensionUrnsList(), plan.getExtensionUrisList(), plan.getExtensionsList()); } public Builder from(ExtendedExpression extendedExpression) { return from( - extendedExpression.getExtensionUrnsList(), extendedExpression.getExtensionsList()); + extendedExpression.getExtensionUrnsList(), + extendedExpression.getExtensionUrisList(), + extendedExpression.getExtensionsList()); } private Builder from( List simpleExtensionURNs, + List simpleExtensionURIs, List simpleExtensionDeclarations) { Map urnMap = new HashMap<>(); + Map uriMap = new HashMap<>(); + // Handle URN format for (SimpleExtensionURN extension : simpleExtensionURNs) { urnMap.put(extension.getExtensionUrnAnchor(), extension.getUrn()); } + // Handle deprecated URI format + for (io.substrait.proto.SimpleExtensionURI extension : simpleExtensionURIs) { + uriMap.put(extension.getExtensionUriAnchor(), extension.getUri()); + } + // Add all functions used in plan to the functionMap for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { @@ -54,14 +209,7 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); int reference = func.getFunctionAnchor(); - String urn = urnMap.get(func.getExtensionUrnReference()); - if (urn == null) { - throw new IllegalStateException( - "Could not find extension URN for function reference " - + func.getExtensionUrnReference()); - } - String name = func.getName(); - SimpleExtension.FunctionAnchor 
anchor = SimpleExtension.FunctionAnchor.of(urn, name); + SimpleExtension.FunctionAnchor anchor = resolveFunctionAnchor(func, urnMap, uriMap); functionMap.put(reference, anchor); } @@ -72,13 +220,7 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); int reference = type.getTypeAnchor(); - String urn = urnMap.get(type.getExtensionUrnReference()); - if (urn == null) { - throw new IllegalStateException( - "Could not find extension URN for type reference " + type.getExtensionUrnReference()); - } - String name = type.getName(); - SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(urn, name); + SimpleExtension.TypeAnchor anchor = resolveTypeAnchor(type, urnMap, uriMap); typeMap.put(reference, anchor); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index d60049ecc..ed45c4f5c 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -554,6 +554,13 @@ public abstract static class ExtensionSignatures { @JsonProperty("urn") public abstract String urn(); + // URI is not from YAML, but from the loading context + // this only needs to be present temporarily to handle the URI -> URN migration + @Value.Default + public String uri() { + return ""; + } + @JsonProperty("scalar_functions") public abstract List scalars(); @@ -627,6 +634,11 @@ public abstract static class ExtensionCollection { Function::getAnchor, java.util.function.Function.identity())); }); + @Value.Default + BidiMap uriUrnMap() { + return new BidiMap<>(); + } + public abstract List types(); public abstract List scalarFunctions(); @@ -703,7 +715,31 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } + /** + * Gets the URI for a given URN. This is for internal framework use during URI/URN migration. 
+ * + * @param urn The URN to look up + * @return The corresponding URI, or null if not found + */ + public String getUriFromUrn(String urn) { + return uriUrnMap().reverseGet(urn); + } + + /** + * Gets the URN for a given URI. This is for internal framework use during URI/URN migration. + * + * @param uri The URI to look up + * @return The corresponding URN, or null if not found + */ + public String getUrnFromUri(String uri) { + return uriUrnMap().get(uri); + } + public ExtensionCollection merge(ExtensionCollection extensionCollection) { + BidiMap mergedUriUrnMap = new BidiMap<>(); + mergedUriUrnMap.merge(uriUrnMap()); + mergedUriUrnMap.merge(extensionCollection.uriUrnMap()); + return ImmutableSimpleExtension.ExtensionCollection.builder() .addAllAggregateFunctions(aggregateFunctions()) .addAllAggregateFunctions(extensionCollection.aggregateFunctions()) @@ -713,6 +749,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { .addAllWindowFunctions(extensionCollection.windowFunctions()) .addAllTypes(types()) .addAllTypes(extensionCollection.types()) + .uriUrnMap(mergedUriUrnMap) .build(); } } @@ -727,7 +764,7 @@ public static ExtensionCollection load(List resourcePaths) { .map( path -> { try (InputStream stream = ExtensionCollection.class.getResourceAsStream(path)) { - return load(stream); + return load(path, stream); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -740,13 +777,15 @@ public static ExtensionCollection load(List resourcePaths) { return complete; } - public static ExtensionCollection load(String content) { + public static ExtensionCollection load(String uri, String content) { try { - // Parse with basic YAML mapper first to extract URN (if present) + if (uri == null || uri.isEmpty()) { + throw new IllegalArgumentException("URI cannot be null or empty"); + } + + // Parse with basic YAML mapper first to extract URN ObjectMapper basicYamlMapper = new ObjectMapper(new YAMLFactory()); 
com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content); - - // URN is required com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn"); if (urnNode == null) { throw new IllegalArgumentException("Extension YAML file must contain a 'urn' field"); @@ -754,25 +793,36 @@ public static ExtensionCollection load(String content) { String urn = urnNode.asText(); validateUrn(urn); - // Then parse with URN-aware mapper - ExtensionSignatures doc = objectMapper(urn).readValue(content, ExtensionSignatures.class); - return buildExtensionCollection(urn, doc); + ExtensionSignatures docWithoutUri = + objectMapper(urn).readValue(content, ExtensionSignatures.class); + + ExtensionSignatures doc = + ImmutableSimpleExtension.ExtensionSignatures.builder() + .from(docWithoutUri) + .uri(uri) + .build(); + + return buildExtensionCollection(uri, doc); } catch (IOException e) { throw new IllegalStateException(e); } } - public static ExtensionCollection load(InputStream stream) { + public static ExtensionCollection load(String uri, InputStream stream) { try (Scanner scanner = new Scanner(stream)) { scanner.useDelimiter("\\A"); String content = scanner.next(); - return load(content); + return load(uri, content); } } public static ExtensionCollection buildExtensionCollection( - String urn, ExtensionSignatures extensionSignatures) { + String uri, ExtensionSignatures extensionSignatures) { + String urn = extensionSignatures.urn(); validateUrn(urn); + if (uri == null || uri.isEmpty()) { + throw new IllegalArgumentException("URI cannot be null or empty"); + } List scalarFunctionVariants = extensionSignatures.scalars().stream() .flatMap(t -> t.resolve(urn)) @@ -805,13 +855,18 @@ public static ExtensionCollection buildExtensionCollection( Stream.concat(windowFunctionVariants, windowAggFunctionVariants) .collect(Collectors.toList()); + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put(uri, urn); + ImmutableSimpleExtension.ExtensionCollection collection =
ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) .aggregateFunctions(aggregateFunctionVariants) .windowFunctions(allWindowFunctionVariants) .addAllTypes(extensionSignatures.types()) + .uriUrnMap(uriUrnMap) .build(); + LOGGER.atDebug().log( "Loaded {} aggregate functions and {} scalar functions from {}.", collection.aggregateFunctions().size(), @@ -819,10 +874,4 @@ public static ExtensionCollection buildExtensionCollection( extensionSignatures.urn()); return collection; } - - public static ExtensionCollection buildExtensionCollection( - ExtensionSignatures extensionSignatures) { - String urn = extensionSignatures.urn(); - return buildExtensionCollection(urn, extensionSignatures); - } } diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index 5b8e2599e..96a3333c8 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -1,6 +1,7 @@ package io.substrait.plan; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; import io.substrait.proto.Rel; @@ -12,9 +13,22 @@ /** Converts from {@link io.substrait.plan.Plan} to {@link io.substrait.proto.Plan} */ public class PlanProtoConverter { + private final SimpleExtension.ExtensionCollection extensionCollection; + + public PlanProtoConverter() { + this(SimpleExtension.loadDefaults()); + } + + public PlanProtoConverter(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } + this.extensionCollection = extensionCollection; + } + public Plan toProto(io.substrait.plan.Plan plan) { List planRels = new ArrayList<>(); - ExtensionCollector functionCollector = new ExtensionCollector(); + 
ExtensionCollector functionCollector = new ExtensionCollector(extensionCollection); for (io.substrait.plan.Plan.Root root : plan.getRoots()) { Rel input = new RelProtoConverter(functionCollector).toProto(root.getInput()); planRels.add( diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 15145fe5c..596398c5b 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -21,6 +21,9 @@ public ProtoPlanConverter() { } public ProtoPlanConverter(SimpleExtension.ExtensionCollection extensionCollection) { + if (extensionCollection == null) { + throw new IllegalArgumentException("ExtensionCollection is required"); + } this.extensionCollection = extensionCollection; } @@ -30,7 +33,8 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java new file mode 100644 index 000000000..d8de8a619 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java @@ -0,0 +1,87 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + 
+public class ExtensionCollectionMergeTest { + + @Test + public void testMergeCollectionsWithDifferentUriUrnMappings() { + String yaml1 = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:ns1:collection1\n" + + "scalar_functions:\n" + + " - name: func1\n" + + " impls:\n" + + " - args: []\n" + + " return: boolean\n"; + + String yaml2 = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:ns2:collection2\n" + + "scalar_functions:\n" + + " - name: func2\n" + + " impls:\n" + + " - args: []\n" + + " return: i32\n"; + + SimpleExtension.ExtensionCollection collection1 = + SimpleExtension.load("uri1://extensions", yaml1); + SimpleExtension.ExtensionCollection collection2 = + SimpleExtension.load("uri2://extensions", yaml2); + + SimpleExtension.ExtensionCollection merged = collection1.merge(collection2); + + assertEquals("extension:ns1:collection1", merged.getUrnFromUri("uri1://extensions")); + assertEquals("extension:ns2:collection2", merged.getUrnFromUri("uri2://extensions")); + assertEquals("uri1://extensions", merged.getUriFromUrn("extension:ns1:collection1")); + assertEquals("uri2://extensions", merged.getUriFromUrn("extension:ns2:collection2")); + + assertTrue(merged.scalarFunctions().size() >= 2); + } + + @Test + public void testMergeCollectionsWithIdenticalMappings() { + String yaml = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:shared:extension\n" + + "scalar_functions:\n" + + " - name: shared_func\n" + + " impls:\n" + + " - args: []\n" + + " return: boolean\n"; + + SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("shared://uri", yaml); + SimpleExtension.ExtensionCollection collection2 = SimpleExtension.load("shared://uri", yaml); + + SimpleExtension.ExtensionCollection merged = + assertDoesNotThrow(() -> collection1.merge(collection2)); + + assertEquals("extension:shared:extension", merged.getUrnFromUri("shared://uri")); + assertEquals("shared://uri", merged.getUriFromUrn("extension:shared:extension")); + } + + @Test + public void 
testMergeCollectionsWithConflictingMappings() { + String yaml1 = + "%YAML 1.2\n" + "---\n" + "urn: extension:conflict:urn1\n" + "scalar_functions: []\n"; + + String yaml2 = + "%YAML 1.2\n" + "---\n" + "urn: extension:conflict:urn2\n" + "scalar_functions: []\n"; + + SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("conflict://uri", yaml1); + SimpleExtension.ExtensionCollection collection2 = + SimpleExtension.load("conflict://uri", yaml2); // Same URI, different URN + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> collection1.merge(collection2)); + assertTrue(exception.getMessage().contains("Key already exists in map with different value")); + } +} diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java new file mode 100644 index 000000000..48f51cbbd --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java @@ -0,0 +1,60 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class ExtensionCollectionUriUrnTest { + + @Test + public void testHasUrnAndHasUri() { + String yamlContent = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:exists\n" + + "scalar_functions:\n" + + " - name: test_function\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("file:///tmp/test.yaml", yamlContent); + + assertTrue(collection.getUrnFromUri("file:///tmp/test.yaml") != null); + assertTrue(collection.getUriFromUrn("extension:test:exists") != null); + assertFalse(collection.getUrnFromUri("nonexistent://uri") != null); + 
+ assertFalse(collection.getUriFromUrn("extension:nonexistent:urn") != null); + } + + @Test + public void testGetNonexistentMappings() { + String yamlContent = + "%YAML 1.2\n" + "---\n" + "urn: extension:test:minimal\n" + "scalar_functions: []\n"; + + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("minimal://extension", yamlContent); + + assertNull(collection.getUrnFromUri("nonexistent://uri")); + assertNull(collection.getUriFromUrn("extension:nonexistent:urn")); + } + + @Test + public void testEmptyUriThrowsException() { + String yamlContent = + "%YAML 1.2\n" + "---\n" + "urn: extension:test:empty\n" + "scalar_functions: []\n"; + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("", yamlContent)); + assertTrue(exception.getMessage().contains("URI cannot be null or empty")); + } + + @Test + public void testNullUriThrowsException() { + String yamlContent = + "%YAML 1.2\n" + "---\n" + "urn: extension:test:null\n" + "scalar_functions: []\n"; + + // load() rejects a null URI with IllegalArgumentException (not an NPE) + assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load(null, yamlContent)); + } +} diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java new file mode 100644 index 000000000..4e6b52fdf --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java @@ -0,0 +1,41 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.proto.Plan; +import org.junit.jupiter.api.Test; + +public class ExtensionCollectorUriUrnTest { + + @Test + public void testExtensionCollectorScalarFuncWithoutURI() { + String uri = "test://uri"; + BidiMap uriUrnMap = new BidiMap(); + uriUrnMap.put(uri, "extension:test:basic"); + + SimpleExtension.ExtensionCollection
extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + ExtensionCollector collector = new ExtensionCollector(extensionCollection); + + SimpleExtension.ScalarFunctionVariant func = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:test:basic") + .name("test_func") + .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) + .build(); + + int functionRef = collector.getFunctionReference(func); + assertEquals(1, functionRef); + + Plan.Builder planBuilder = Plan.newBuilder(); + collector.addExtensionsToPlan(planBuilder); + + Plan plan = planBuilder.build(); + assertEquals(1, plan.getExtensionUrnsCount()); + assertEquals("extension:test:basic", plan.getExtensionUrns(0).getUrn()); + + assertEquals(1, plan.getExtensionUrisCount()); + assertEquals("test://uri", plan.getExtensionUris(0).getUri()); + } +} diff --git a/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java new file mode 100644 index 000000000..f035eaccb --- /dev/null +++ b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java @@ -0,0 +1,621 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.proto.Plan; +import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; +import org.junit.jupiter.api.Test; + +public class ImmutableExtensionLookupUriUrnTest { + + @Test + public void testUrnResolutionWorks() { + // Create URN-only plan (normal case) + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(1) + .setUrn("extension:test:urn") + .build(); + + 
SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("test_func") + .setExtensionUrnReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + // Test with no ExtensionCollection (no URI/URN mapping available) + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:urn", lookup.getFunctionAnchor(1).urn()); + assertEquals("test_func", lookup.getFunctionAnchor(1).key()); + } + + @Test + public void testUriToUrnFallbackWorks() { + // Create an ExtensionCollection with URI/URN mapping + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/extensions/test", "extension:test:mapped"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + // Create URI-only plan (legacy case) + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/extensions/test") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("legacy_func") + .setExtensionUriReference(1) // References the URI anchor (deprecated field) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + // Test with URI/URN mapping - should resolve URI to URN + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:mapped", lookup.getFunctionAnchor(1).urn()); + 
assertEquals("legacy_func", lookup.getFunctionAnchor(1).key()); + } + + @Test + public void testUriWithoutMappingThrowsError() { + // Create URI-only plan without mapping + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/unmapped") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("unmapped_func") + .setExtensionUriReference(1) // References the URI anchor + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + // Should throw error - URI present but no mapping available + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder().from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("All resolution strategies failed")); + assertTrue(exception.getMessage().contains("http://example.com/unmapped")); + assertTrue(exception.getMessage().contains("URI <-> URN mapping")); + } + + @Test + public void testMissingUrnAndUriThrowsError() { + // Create plan with missing URN/URI reference + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("missing_func") + .setExtensionUrnReference(999) // Non-existent reference + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensions(decl).build(); + + // Should throw error - neither URN nor URI found + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder().from(plan).build(); + }); + + 
assertTrue(exception.getMessage().contains("All resolution strategies failed")); + assertTrue(exception.getMessage().contains("null")); // Both URI and URN should be null + } + + // ========================================================================== + // Simple tests for all 5 resolution cases - Functions + // ========================================================================== + + @Test + public void testFunctionCase1_NonZeroUrnReference() { + // Case 1: Non-zero URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(1) + .setUrn("extension:test:case1") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case1_func") + .setExtensionUrnReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case1", lookup.getFunctionAnchor(1).urn()); + assertEquals("case1_func", lookup.getFunctionAnchor(1).key()); + } + + @Test + public void testFunctionCase2_NonZeroUriReference() { + // Case 2: Non-zero URI reference resolves via mapping + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case2", "extension:test:case2"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/case2") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case2_func") + 
.setExtensionUriReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case2", lookup.getFunctionAnchor(1).urn()); + assertEquals("case2_func", lookup.getFunctionAnchor(1).key()); + } + + @Test + public void testFunctionCase3_ZeroBothResolveConsistent() { + // Case 3: Both 0 references resolve to consistent URN + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case3", "extension:test:case3"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case3") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case3") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case3_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case3", lookup.getFunctionAnchor(1).urn()); + assertEquals("case3_func", lookup.getFunctionAnchor(1).key()); + } + + @Test + public void testFunctionCase3_ZeroBothResolveConflict() { + // Case 
3: Both 0 references resolve but to different URNs - should throw + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/conflict", "extension:test:different"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:original") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/conflict") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("conflict_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("Conflicting URI/URN mapping")); + assertTrue(exception.getMessage().contains("These must be consistent")); + } + + @Test + public void testFunctionCase4_ZeroUrnOnly() { + // Case 4: Only 0 URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case4") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case4_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + 
SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case4", lookup.getFunctionAnchor(1).urn()); + assertEquals("case4_func", lookup.getFunctionAnchor(1).key()); + } + + @Test + public void testFunctionCase5_ZeroUriOnly() { + // Case 5: Only 0 URI reference resolves + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case5", "extension:test:case5"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case5") + .build(); + + SimpleExtensionDeclaration.ExtensionFunction func = + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("case5_func") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case5", lookup.getFunctionAnchor(1).urn()); + assertEquals("case5_func", lookup.getFunctionAnchor(1).key()); + } + + // ========================================================================== + // Simple tests for all 5 resolution cases - Types + // ========================================================================== + + @Test + public void testTypeCase1_NonZeroUrnReference() { + // Case 1: Non-zero URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + 
.setExtensionUrnAnchor(1) + .setUrn("extension:test:case1") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case1_type") + .setExtensionUrnReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case1", lookup.getTypeAnchor(1).urn()); + assertEquals("case1_type", lookup.getTypeAnchor(1).key()); + } + + @Test + public void testTypeCase2_NonZeroUriReference() { + // Case 2: Non-zero URI reference resolves via mapping + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case2", "extension:test:case2"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/case2") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case2_type") + .setExtensionUriReference(1) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case2", lookup.getTypeAnchor(1).urn()); + assertEquals("case2_type", lookup.getTypeAnchor(1).key()); + } + + @Test + public void testTypeCase3_ZeroBothResolveConsistent() { + // Case 3: Both 0 references resolve to consistent URN + BidiMap 
uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case3", "extension:test:case3"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case3") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case3") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case3_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case3", lookup.getTypeAnchor(1).urn()); + assertEquals("case3_type", lookup.getTypeAnchor(1).key()); + } + + @Test + public void testTypeCase3_ZeroBothResolveConflict() { + // Case 3: Both 0 references resolve but to different URNs - should throw + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/conflict", "extension:test:different"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:original") + .build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/conflict") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + 
SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("conflict_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = + Plan.newBuilder() + .addExtensionUrns(urnProto) + .addExtensionUris(uriProto) + .addExtensions(decl) + .build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + }); + + assertTrue(exception.getMessage().contains("Conflicting URI/URN mapping")); + assertTrue(exception.getMessage().contains("These must be consistent")); + } + + @Test + public void testTypeCase4_ZeroUrnOnly() { + // Case 4: Only 0 URN reference resolves + SimpleExtensionURN urnProto = + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(0) + .setUrn("extension:test:case4") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case4_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + + assertEquals("extension:test:case4", lookup.getTypeAnchor(1).urn()); + assertEquals("case4_type", lookup.getTypeAnchor(1).key()); + } + + @Test + public void testTypeCase5_ZeroUriOnly() { + // Case 5: Only 0 URI reference resolves + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/case5", "extension:test:case5"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + 
+ SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(0) + .setUri("http://example.com/case5") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("case5_type") + .setExtensionUrnReference(0) + .setExtensionUriReference(0) + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:test:case5", lookup.getTypeAnchor(1).urn()); + assertEquals("case5_type", lookup.getTypeAnchor(1).key()); + } + + @Test + public void testTypeUriToUrnFallbackWorks() { + // Test the same logic but for types instead of functions + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("http://example.com/types/test", "extension:types:mapped"); + + SimpleExtension.ExtensionCollection extensionCollection = + SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); + + SimpleExtensionURI uriProto = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("http://example.com/types/test") + .build(); + + SimpleExtensionDeclaration.ExtensionType type = + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(1) + .setName("legacy_type") + .setExtensionUriReference(1) // References the URI anchor + .build(); + + SimpleExtensionDeclaration decl = + SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); + + Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + + ImmutableExtensionLookup lookup = + ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); + + assertEquals("extension:types:mapped", lookup.getTypeAnchor(1).urn()); + assertEquals("legacy_type", 
lookup.getTypeAnchor(1).key()); + } +} diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index e3bd9b7f4..2e8fd8403 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -32,9 +32,9 @@ public class TypeExtensionTest { final SimpleExtension.ExtensionCollection extensionCollection; { - InputStream inputStream = - this.getClass().getResourceAsStream("/extensions/custom_extensions.yaml"); - extensionCollection = SimpleExtension.load(inputStream); + String path = "/extensions/custom_extensions.yaml"; + InputStream inputStream = this.getClass().getResourceAsStream(path); + extensionCollection = SimpleExtension.load(path, inputStream); } final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); diff --git a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java new file mode 100644 index 000000000..9cf215ce1 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java @@ -0,0 +1,110 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.plan.PlanProtoConverter; +import io.substrait.plan.ProtoPlanConverter; +import io.substrait.proto.Plan; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +/** + * End-to-end tests demonstrating the full URI/URN migration workflow: 1. 
Consume plans with mixed + * URI/URN references 2. Convert proto -> POJO using ImmutableExtensionLookup with URI/URN mapping + * 3. Convert POJO -> proto using PlanProtoConverter 4. Verify output contains proper + * extension information + */ +public class UriUrnMigrationEndToEndTest { + + private final SimpleExtension.ExtensionCollection defaultExtensions = + SimpleExtension.loadDefaults(); + + /** Load a proto Plan from a JSON resource file using JsonFormat */ + private Plan loadPlanFromJson(String resourcePath) throws IOException { + try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream(resourcePath)) { + if (inputStream == null) { + throw new IOException("Resource not found: " + resourcePath); + } + + String jsonContent = + new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + .lines() + .collect(Collectors.joining("\n")); + + Plan.Builder planBuilder = Plan.newBuilder(); + JsonFormat.parser().merge(jsonContent, planBuilder); + return planBuilder.build(); + } + } + + @Test + public void testUriUrnMigrationEndToEnd() throws IOException { + + // List of (inputPath, expectedPath) pairs + List testCases = + Arrays.asList( + new String[] { + "uri-urn-migration/uri-only-input-plan.json", + "uri-urn-migration/uri-only-expected-plan.json" + }, + new String[] { + "uri-urn-migration/complex-input-plan.json", + "uri-urn-migration/complex-expected-plan.json" + }, + new String[] { + "uri-urn-migration/urn-only-input-plan.json", + "uri-urn-migration/urn-only-expected-plan.json" + }, + new String[] { + "uri-urn-migration/mixed-partial-coverage-input-plan.json", + "uri-urn-migration/mixed-partial-coverage-expected-plan.json" + }, + new String[] { + "uri-urn-migration/zero-urn-resolution-input-plan.json", + "uri-urn-migration/zero-urn-resolution-expected-plan.json" + }); + + for (String[] testCase : testCases) { + String inputPath = testCase[0]; + String expectedPath = testCase[1]; + + Plan inputPlan =
loadPlanFromJson(inputPath); + Plan expectedPlan = loadPlanFromJson(expectedPath); + + ProtoPlanConverter protoToPojo = new ProtoPlanConverter(defaultExtensions); + io.substrait.plan.Plan pojoPlan = protoToPojo.from(inputPlan); + + PlanProtoConverter pojoToProto = new PlanProtoConverter(defaultExtensions); + Plan actualPlan = pojoToProto.toProto(pojoPlan); + + assertEquals(expectedPlan, actualPlan); + } + } + + @Test + public void testUnresolvableUriThrowsException() throws IOException { + Plan inputPlan = loadPlanFromJson("uri-urn-migration/unresolvable-uri-plan.json"); + + ProtoPlanConverter protoToPojo = new ProtoPlanConverter(defaultExtensions); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> { + protoToPojo.from(inputPlan); + }); + + assertTrue(exception.getMessage().contains("All resolution strategies failed")); + assertTrue(exception.getMessage().contains("/functions_nonexistent.yaml")); + } +} diff --git a/core/src/test/java/io/substrait/extension/UrnValidationTest.java b/core/src/test/java/io/substrait/extension/UrnValidationTest.java index 3d3354a6f..983eb4f41 100644 --- a/core/src/test/java/io/substrait/extension/UrnValidationTest.java +++ b/core/src/test/java/io/substrait/extension/UrnValidationTest.java @@ -1,6 +1,7 @@ package io.substrait.extension; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -12,7 +13,8 @@ public class UrnValidationTest { public void testMissingUrnThrowsException() { String yamlWithoutUrn = "%YAML 1.2\n" + "---\n" + "scalar_functions:\n" + " - name: test\n"; IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load(yamlWithoutUrn)); + assertThrows( + IllegalArgumentException.class, () -> SimpleExtension.load("some/uri", 
yamlWithoutUrn)); assertTrue(exception.getMessage().contains("Extension YAML file must contain a 'urn' field")); } @@ -26,7 +28,8 @@ public void testInvalidUrnFormatThrowsException() { + " - name: test\n"; IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, () -> SimpleExtension.load(yamlWithInvalidUrn)); + IllegalArgumentException.class, + () -> SimpleExtension.load("some/uri", yamlWithInvalidUrn)); assertTrue( exception.getMessage().contains("URN must follow format 'extension::'")); } @@ -39,6 +42,19 @@ public void testValidUrnWorks() { + "urn: extension:test:valid\n" + "scalar_functions:\n" + " - name: test\n"; - assertDoesNotThrow(() -> SimpleExtension.load(yamlWithValidUrn)); + assertDoesNotThrow(() -> SimpleExtension.load("some/uri", yamlWithValidUrn)); + } + + @Test + public void testUriUrnMapIsPopulated() { + String yamlWithValidUrn = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:valid\n" + + "scalar_functions:\n" + + " - name: test\n"; + SimpleExtension.ExtensionCollection collection = + SimpleExtension.load("test://uri", yamlWithValidUrn); + assertEquals("extension:test:valid", collection.getUrnFromUri("test://uri")); } } diff --git a/core/src/test/resources/uri-urn-migration/complex-expected-plan.json b/core/src/test/resources/uri-urn-migration/complex-expected-plan.json new file mode 100644 index 000000000..09833ace9 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/complex-expected-plan.json @@ -0,0 +1,162 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUriReference": 1, + 
"extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 3, + "name": "multiply:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 5 + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 7 + } + } + } + ] + } + }, + { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 3 + } + } + }, + { + "value": { + "literal": { + "i32": 4 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result", + "product" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/complex-input-plan.json b/core/src/test/resources/uri-urn-migration/complex-input-plan.json new file mode 100644 index 000000000..da28cbb33 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/complex-input-plan.json @@ -0,0 +1,148 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 0, + "uri": "/functions_string.yaml" + } + ], + "extensionUrns": [ + { + 
"extensionUrnAnchor": 2, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUrnReference": 2 + } + }, + { + "extensionFunction": { + "functionAnchor": 3, + "name": "multiply:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 5 + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 7 + } + } + } + ] + } + }, + { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 3 + } + } + }, + { + "value": { + "literal": { + "i32": 4 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result", + "product" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json new file mode 100644 index 000000000..2898243b3 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-expected-plan.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 
1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "multiply:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "x", + "y" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 2 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "names": [ + "calculation" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json new file mode 100644 index 000000000..e02019925 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/mixed-partial-coverage-input-plan.json @@ -0,0 +1,110 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + 
"name": "add:i32_i32", + "extensionUriReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "multiply:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "x", + "y" + ], + "struct": { + "types": [ + { + "i32": {} + }, + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 2 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "names": [ + "calculation" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json b/core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json new file mode 100644 index 000000000..769d359e7 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/unresolvable-uri-plan.json @@ -0,0 +1,73 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_nonexistent.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "nonexistent_function:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 7 + } + } + } + 
] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json b/core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json new file mode 100644 index 000000000..1ad0150e3 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/uri-only-expected-plan.json @@ -0,0 +1,98 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 7 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/uri-only-input-plan.json b/core/src/test/resources/uri-urn-migration/uri-only-input-plan.json new file mode 100644 index 000000000..5630a00ca --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/uri-only-input-plan.json @@ -0,0 +1,80 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + 
"extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 7 + } + } + }, + { + "value": { + "literal": { + "i32": 3 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json b/core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json new file mode 100644 index 000000000..9e47d1e74 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/urn-only-expected-plan.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "value1", + "value2" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, 
+ "namedTable": { + "names": [ + "numbers" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 20 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/urn-only-input-plan.json b/core/src/test/resources/uri-urn-migration/urn-only-input-plan.json new file mode 100644 index 000000000..b5f5af06e --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/urn-only-input-plan.json @@ -0,0 +1,110 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUrnReference": 1 + } + }, + { + "extensionFunction": { + "functionAnchor": 2, + "name": "subtract:i32_i32", + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "value1", + "value2" + ], + "struct": { + "types": [ + { + "i32": {} + }, + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "numbers" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 20 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } 
+ ] + } + } + }, + { + "value": { + "literal": { + "i32": 10 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json new file mode 100644 index 000000000..b8788c850 --- /dev/null +++ b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-expected-plan.json @@ -0,0 +1,104 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 1, + "extensionUrnReference": 1 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a", + "b" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json new file mode 100644 index 000000000..bc4983846 
--- /dev/null +++ b/core/src/test/resources/uri-urn-migration/zero-urn-resolution-input-plan.json @@ -0,0 +1,85 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 0, + "urn": "extension:io.substrait:functions_arithmetic" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 1, + "name": "add:i32_i32", + "extensionUriReference": 0, + "extensionUrnReference": 0 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "a", + "b" + ], + "struct": { + "types": [ + { + "i32": {} + }, + { + "i32": {} + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": {} + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10 + } + } + }, + { + "value": { + "literal": { + "i32": 5 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "sum_result" + ] + } + } + ], + "version": { + "minorNumber": 75 + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 7e59a82a2..d351b1bca 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -56,7 +56,7 @@ public class CustomFunctionTest extends PlanTestBase { // Load custom extension into an ExtensionCollection static final SimpleExtension.ExtensionCollection extensionCollection = - SimpleExtension.load(FUNCTIONS_CUSTOM); + SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index f6ca3b150..b10fbab09 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java 
+++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -37,7 +37,7 @@ public class RelCopyOnWriteVisitorTest extends PlanTestBase { private static final String COUNT_DISTINCT_SUBBQUERY = "select\n" + " count(distinct l.l_orderkey),\n" - + " count(distinct l.l_orderkey) + 1,\n" + + " count(distinct l.l_orderkey) + 1,\n" + " sum(l.l_extendedprice * (1 - l.l_discount)) as revenue,\n" + " o.o_orderdate,\n" + " count(distinct o.o_shippriority)\n" diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index 595d5c169..02d67c4b5 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -29,7 +29,7 @@ object SparkExtension { final val file = "/spark.yml" private val SparkImpls: SimpleExtension.ExtensionCollection = - SimpleExtension.load(getClass.getResourceAsStream(file)) + SimpleExtension.load(file, getClass.getResourceAsStream(file)) private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = DefaultExtensionCatalog.DEFAULT_COLLECTION; From 30b2df88d31e0cde8b1488888515416cab96d284 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 1 Oct 2025 17:28:13 -0400 Subject: [PATCH 5/6] fix: fix compilation by using DEFAULT_COLLECTION --- .../io/substrait/extension/ExtensionCollector.java | 2 +- .../extension/ImmutableExtensionLookup.java | 2 +- .../java/io/substrait/plan/PlanProtoConverter.java | 3 ++- .../extension/UriUrnMigrationEndToEndTest.java | 12 ++++++------ 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index ec313d860..7ad07a6b1 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -30,7 +30,7 @@ private String 
getUriFromUrn(String urn) { } public ExtensionCollector() { - this(SimpleExtension.loadDefaults()); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } public ExtensionCollector(SimpleExtension.ExtensionCollection extensionCollection) { diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index a8a365dbe..3dbd5f121 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -35,7 +35,7 @@ public static class Builder { private final SimpleExtension.ExtensionCollection extensionCollection; public Builder() { - this.extensionCollection = SimpleExtension.loadDefaults(); + this.extensionCollection = DefaultExtensionCatalog.DEFAULT_COLLECTION; } public Builder(SimpleExtension.ExtensionCollection extensionCollection) { diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index 96a3333c8..5785ccd70 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -1,5 +1,6 @@ package io.substrait.plan; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.proto.Plan; @@ -16,7 +17,7 @@ public class PlanProtoConverter { private final SimpleExtension.ExtensionCollection extensionCollection; public PlanProtoConverter() { - this(SimpleExtension.loadDefaults()); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } public PlanProtoConverter(SimpleExtension.ExtensionCollection extensionCollection) { diff --git a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java index 9cf215ce1..8692284fe 100644 
--- a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java +++ b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java @@ -26,9 +26,6 @@ */ public class UriUrnMigrationEndToEndTest { - private final SimpleExtension.ExtensionCollection defaultExtensions = - SimpleExtension.loadDefaults(); - /** Load a proto Plan from a JSON resource file using JsonFormat */ private Plan loadPlanFromJson(String resourcePath) throws IOException { try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream(resourcePath)) { @@ -81,10 +78,12 @@ public void testUriUrnMigrationEndToEnd() throws IOException { Plan inputPlan = loadPlanFromJson(inputPath); Plan expectedPlan = loadPlanFromJson(expectedPath); - ProtoPlanConverter protoToPojo = new ProtoPlanConverter(defaultExtensions); + ProtoPlanConverter protoToPojo = + new ProtoPlanConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); io.substrait.plan.Plan pojoPlan = protoToPojo.from(inputPlan); - PlanProtoConverter pojoToProto = new PlanProtoConverter(defaultExtensions); + PlanProtoConverter pojoToProto = + new PlanProtoConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); Plan actualPlan = pojoToProto.toProto(pojoPlan); assertEquals(expectedPlan, actualPlan); @@ -95,7 +94,8 @@ public void testUriUrnMigrationEndToEnd() throws IOException { public void testUnresolvableUriThrowsException() throws IOException { Plan inputPlan = loadPlanFromJson("uri-urn-migration/unresolvable-uri-plan.json"); - ProtoPlanConverter protoToPojo = new ProtoPlanConverter(defaultExtensions); + ProtoPlanConverter protoToPojo = + new ProtoPlanConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); IllegalStateException exception = assertThrows( From 89ce8ec3bb2e61bfa4743073927425899a84738c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Fri, 3 Oct 2025 15:15:24 -0400 Subject: [PATCH 6/6] fix: address @vbarua comments --- .../extension/AbstractExtensionLookup.java | 20 ----- 
.../extension/ImmutableExtensionLookup.java | 78 ++++++++++++------- .../substrait/extension/SimpleExtension.java | 21 ++--- .../ImmutableExtensionLookupUriUrnTest.java | 58 +++++++------- .../UriUrnMigrationEndToEndTest.java | 2 +- 5 files changed, 88 insertions(+), 91 deletions(-) diff --git a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java index 4693a2674..16e41f03f 100644 --- a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java @@ -13,26 +13,6 @@ public AbstractExtensionLookup( this.typeAnchorMap = typeAnchorMap; } - /** - * Gets the function anchor for a given reference (primarily for testing). - * - * @param reference The function reference - * @return The function anchor, or null if not found - */ - public SimpleExtension.FunctionAnchor getFunctionAnchor(int reference) { - return functionAnchorMap.get(reference); - } - - /** - * Gets the type anchor for a given reference (primarily for testing). 
- * - * @param reference The type reference - * @return The type anchor, or null if not found - */ - public SimpleExtension.TypeAnchor getTypeAnchor(int reference) { - return typeAnchorMap.get(reference); - } - @Override public SimpleExtension.ScalarFunctionVariant getScalarFunction( int reference, SimpleExtension.ExtensionCollection extensions) { diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 3dbd5f121..9b546d9dc 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -22,7 +22,7 @@ private ImmutableExtensionLookup( } public static Builder builder() { - return new Builder(); + return builder(DefaultExtensionCatalog.DEFAULT_COLLECTION); } public static Builder builder(SimpleExtension.ExtensionCollection extensionCollection) { @@ -34,10 +34,6 @@ public static class Builder { private final Map typeMap = new HashMap<>(); private final SimpleExtension.ExtensionCollection extensionCollection; - public Builder() { - this.extensionCollection = DefaultExtensionCatalog.DEFAULT_COLLECTION; - } - public Builder(SimpleExtension.ExtensionCollection extensionCollection) { if (extensionCollection == null) { throw new IllegalArgumentException("ExtensionCollection is required"); @@ -63,26 +59,40 @@ private SimpleExtension.FunctionAnchor resolveFunctionAnchor( // 1. 
Try non-zero URN reference if (func.getExtensionUrnReference() != 0) { String urnFromUrnRef = urnMap.get(func.getExtensionUrnReference()); - if (urnFromUrnRef != null) { - return SimpleExtension.FunctionAnchor.of(urnFromUrnRef, func.getName()); + if (urnFromUrnRef == null) { + throw new IllegalStateException( + String.format( + "Function '%s' references URN anchor %d, but no URN is registered at that anchor", + func.getName(), func.getExtensionUrnReference())); } + return SimpleExtension.FunctionAnchor.of(urnFromUrnRef, func.getName()); } // 2. Try non-zero URI reference if (func.getExtensionUriReference() != 0) { String uriFromUriRef = uriMap.get(func.getExtensionUriReference()); - if (uriFromUriRef != null) { - String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); - if (urnFromUriRef != null) { - return SimpleExtension.FunctionAnchor.of(urnFromUriRef, func.getName()); - } - // URI found but could not be resolved to URN + if (uriFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Function '%s' references URI anchor %d, but no URI is registered at that anchor", + func.getName(), func.getExtensionUriReference())); } + String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + if (urnFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Function '%s' references URI anchor %d with URI '%s', but this URI could not be resolved to a URN. " + + "Ensure a URI <-> URN mapping is registered in the ExtensionCollection.", + func.getName(), func.getExtensionUriReference(), uriFromUriRef)); + } + return SimpleExtension.FunctionAnchor.of(urnFromUriRef, func.getName()); } - /* At this point both URI and URN are 0, so we need to - first see if they both resolve. - */ + /* At this point, both URI and URN are known to be 0. + * With protobufs, we cannot distinguish between 0 as an + * intentional value vs 0 as a default value. + * We perform some additional checks below to handle this. 
+ */ String urn = urnMap.get(func.getExtensionUrnReference()); String uri = uriMap.get(func.getExtensionUriReference()); @@ -122,26 +132,40 @@ private SimpleExtension.TypeAnchor resolveTypeAnchor( // 1. Try non-zero URN reference if (type.getExtensionUrnReference() != 0) { String urnFromUrnRef = urnMap.get(type.getExtensionUrnReference()); - if (urnFromUrnRef != null) { - return SimpleExtension.TypeAnchor.of(urnFromUrnRef, type.getName()); + if (urnFromUrnRef == null) { + throw new IllegalStateException( + String.format( + "Type '%s' references URN anchor %d, but no URN is registered at that anchor", + type.getName(), type.getExtensionUrnReference())); } + return SimpleExtension.TypeAnchor.of(urnFromUrnRef, type.getName()); } // 2. Try non-zero URI reference if (type.getExtensionUriReference() != 0) { String uriFromUriRef = uriMap.get(type.getExtensionUriReference()); - if (uriFromUriRef != null) { - String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); - if (urnFromUriRef != null) { - return SimpleExtension.TypeAnchor.of(urnFromUriRef, type.getName()); - } - // URI found but could not be resolved to URN + if (uriFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Type '%s' references URI anchor %d, but no URI is registered at that anchor", + type.getName(), type.getExtensionUriReference())); + } + String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + if (urnFromUriRef == null) { + throw new IllegalStateException( + String.format( + "Type '%s' references URI anchor %d with URI '%s', but this URI could not be resolved to a URN. " + + "Ensure a URI <-> URN mapping is registered in the ExtensionCollection.", + type.getName(), type.getExtensionUriReference(), uriFromUriRef)); } + return SimpleExtension.TypeAnchor.of(urnFromUriRef, type.getName()); } - /* At this point both URI and URN are 0, so we need to - first see if they both resolve. - */ + /* At this point, both URI and URN are known to be 0. 
+ * With protobufs, we cannot distinguish between 0 as an + * intentional value vs 0 as a default value. + * We perform some additional checks below to handle this. + */ String urn = urnMap.get(type.getExtensionUrnReference()); String uri = uriMap.get(type.getExtensionUriReference()); diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index ed45c4f5c..3198f7a26 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -50,6 +50,9 @@ public class SimpleExtension { private static final Predicate URN_CHECKER = Pattern.compile("^extension:[^:]+:[^:]+$").asPredicate(); + // `\A` means beginning of input. Using it as a delimiter in a scanner reads in the whole file. + private static Pattern READ_WHOLE_FILE = Pattern.compile("\\A"); + private static void validateUrn(String urn) { if (urn == null || urn.trim().isEmpty()) { throw new IllegalArgumentException("URN cannot be null or empty"); @@ -554,13 +557,6 @@ public abstract static class ExtensionSignatures { @JsonProperty("urn") public abstract String urn(); - // URI is not from YAML, but from the loading context - // this only needs to be present temporarily to handle the URI -> URN migration - @Value.Default - public String uri() { - return ""; - } - @JsonProperty("scalar_functions") public abstract List scalars(); @@ -721,7 +717,7 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { * @param urn The URN to look up * @return The corresponding URI, or null if not found */ - public String getUriFromUrn(String urn) { + String getUriFromUrn(String urn) { return uriUrnMap().reverseGet(urn); } @@ -731,7 +727,7 @@ public String getUriFromUrn(String urn) { * @param uri The URI to look up * @return The corresponding URN, or null if not found */ - public String getUrnFromUri(String uri) { + String getUrnFromUri(String uri) { 
return uriUrnMap().get(uri); } @@ -797,10 +793,7 @@ public static ExtensionCollection load(String uri, String content) { objectMapper(urn).readValue(content, ExtensionSignatures.class); ExtensionSignatures doc = - ImmutableSimpleExtension.ExtensionSignatures.builder() - .from(docWithoutUri) - .uri(uri) - .build(); + ImmutableSimpleExtension.ExtensionSignatures.builder().from(docWithoutUri).build(); return buildExtensionCollection(uri, doc); } catch (IOException e) { @@ -810,7 +803,7 @@ public static ExtensionCollection load(String uri, String content) { public static ExtensionCollection load(String uri, InputStream stream) { try (Scanner scanner = new Scanner(stream)) { - scanner.useDelimiter("\\A"); + scanner.useDelimiter(READ_WHOLE_FILE); String content = scanner.next(); return load(uri, content); } diff --git a/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java index f035eaccb..010bdfef4 100644 --- a/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java +++ b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java @@ -36,8 +36,8 @@ public void testUrnResolutionWorks() { // Test with no ExtensionCollection (no URI/URN mapping available) ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); - assertEquals("extension:test:urn", lookup.getFunctionAnchor(1).urn()); - assertEquals("test_func", lookup.getFunctionAnchor(1).key()); + assertEquals("extension:test:urn", lookup.functionAnchorMap.get(1).urn()); + assertEquals("test_func", lookup.functionAnchorMap.get(1).key()); } @Test @@ -72,8 +72,8 @@ public void testUriToUrnFallbackWorks() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:mapped", lookup.getFunctionAnchor(1).urn()); - assertEquals("legacy_func", 
lookup.getFunctionAnchor(1).key()); + assertEquals("extension:test:mapped", lookup.functionAnchorMap.get(1).urn()); + assertEquals("legacy_func", lookup.functionAnchorMap.get(1).key()); } @Test @@ -105,7 +105,7 @@ public void testUriWithoutMappingThrowsError() { ImmutableExtensionLookup.builder().from(plan).build(); }); - assertTrue(exception.getMessage().contains("All resolution strategies failed")); + assertTrue(exception.getMessage().contains("could not be resolved to a URN")); assertTrue(exception.getMessage().contains("http://example.com/unmapped")); assertTrue(exception.getMessage().contains("URI <-> URN mapping")); } @@ -133,8 +133,8 @@ public void testMissingUrnAndUriThrowsError() { ImmutableExtensionLookup.builder().from(plan).build(); }); - assertTrue(exception.getMessage().contains("All resolution strategies failed")); - assertTrue(exception.getMessage().contains("null")); // Both URI and URN should be null + assertTrue(exception.getMessage().contains("no URN is registered at that anchor")); + assertTrue(exception.getMessage().contains("999")); // The missing anchor reference } // ========================================================================== @@ -164,8 +164,8 @@ public void testFunctionCase1_NonZeroUrnReference() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); - assertEquals("extension:test:case1", lookup.getFunctionAnchor(1).urn()); - assertEquals("case1_func", lookup.getFunctionAnchor(1).key()); + assertEquals("extension:test:case1", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case1_func", lookup.functionAnchorMap.get(1).key()); } @Test @@ -198,8 +198,8 @@ public void testFunctionCase2_NonZeroUriReference() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:case2", lookup.getFunctionAnchor(1).urn()); - assertEquals("case2_func", lookup.getFunctionAnchor(1).key()); + 
assertEquals("extension:test:case2", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case2_func", lookup.functionAnchorMap.get(1).key()); } @Test @@ -244,8 +244,8 @@ public void testFunctionCase3_ZeroBothResolveConsistent() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:case3", lookup.getFunctionAnchor(1).urn()); - assertEquals("case3_func", lookup.getFunctionAnchor(1).key()); + assertEquals("extension:test:case3", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case3_func", lookup.functionAnchorMap.get(1).key()); } @Test @@ -322,8 +322,8 @@ public void testFunctionCase4_ZeroUrnOnly() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); - assertEquals("extension:test:case4", lookup.getFunctionAnchor(1).urn()); - assertEquals("case4_func", lookup.getFunctionAnchor(1).key()); + assertEquals("extension:test:case4", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case4_func", lookup.functionAnchorMap.get(1).key()); } @Test @@ -357,8 +357,8 @@ public void testFunctionCase5_ZeroUriOnly() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:case5", lookup.getFunctionAnchor(1).urn()); - assertEquals("case5_func", lookup.getFunctionAnchor(1).key()); + assertEquals("extension:test:case5", lookup.functionAnchorMap.get(1).urn()); + assertEquals("case5_func", lookup.functionAnchorMap.get(1).key()); } // ========================================================================== @@ -388,8 +388,8 @@ public void testTypeCase1_NonZeroUrnReference() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); - assertEquals("extension:test:case1", lookup.getTypeAnchor(1).urn()); - assertEquals("case1_type", lookup.getTypeAnchor(1).key()); + assertEquals("extension:test:case1", lookup.typeAnchorMap.get(1).urn()); + 
assertEquals("case1_type", lookup.typeAnchorMap.get(1).key()); } @Test @@ -422,8 +422,8 @@ public void testTypeCase2_NonZeroUriReference() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:case2", lookup.getTypeAnchor(1).urn()); - assertEquals("case2_type", lookup.getTypeAnchor(1).key()); + assertEquals("extension:test:case2", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case2_type", lookup.typeAnchorMap.get(1).key()); } @Test @@ -468,8 +468,8 @@ public void testTypeCase3_ZeroBothResolveConsistent() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:case3", lookup.getTypeAnchor(1).urn()); - assertEquals("case3_type", lookup.getTypeAnchor(1).key()); + assertEquals("extension:test:case3", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case3_type", lookup.typeAnchorMap.get(1).key()); } @Test @@ -546,8 +546,8 @@ public void testTypeCase4_ZeroUrnOnly() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); - assertEquals("extension:test:case4", lookup.getTypeAnchor(1).urn()); - assertEquals("case4_type", lookup.getTypeAnchor(1).key()); + assertEquals("extension:test:case4", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case4_type", lookup.typeAnchorMap.get(1).key()); } @Test @@ -581,8 +581,8 @@ public void testTypeCase5_ZeroUriOnly() { ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:test:case5", lookup.getTypeAnchor(1).urn()); - assertEquals("case5_type", lookup.getTypeAnchor(1).key()); + assertEquals("extension:test:case5", lookup.typeAnchorMap.get(1).urn()); + assertEquals("case5_type", lookup.typeAnchorMap.get(1).key()); } @Test @@ -615,7 +615,7 @@ public void testTypeUriToUrnFallbackWorks() { ImmutableExtensionLookup lookup = 
ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - assertEquals("extension:types:mapped", lookup.getTypeAnchor(1).urn()); - assertEquals("legacy_type", lookup.getTypeAnchor(1).key()); + assertEquals("extension:types:mapped", lookup.typeAnchorMap.get(1).urn()); + assertEquals("legacy_type", lookup.typeAnchorMap.get(1).key()); } } diff --git a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java index 8692284fe..ca2334c9c 100644 --- a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java +++ b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java @@ -104,7 +104,7 @@ public void testUnresolvableUriThrowsException() throws IOException { protoToPojo.from(inputPlan); }); - assertTrue(exception.getMessage().contains("All resolution strategies failed")); + assertTrue(exception.getMessage().contains("could not be resolved to a URN")); assertTrue(exception.getMessage().contains("/functions_nonexistent.yaml")); } }