From fec8f99667f8f44af2855e985f760038655488db Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 30 Sep 2025 17:28:38 -0700 Subject: [PATCH] feat: introduce DefaultExtensionCatalog.DEFAULT_COLLECTION Provide convenient static instance of default ExtensionCollection BREAKING CHANGE: removed SimpleExtension.loadDefaults --- .../ProtoExtendedExpressionConverter.java | 3 +- .../extension/DefaultExtensionCatalog.java | 28 +++++++++++++++++++ .../substrait/extension/SimpleExtension.java | 22 --------------- .../io/substrait/plan/ProtoPlanConverter.java | 3 +- .../ProtoAggregateFunctionConverter.java | 3 +- .../substrait/relation/ProtoRelConverter.java | 3 +- core/src/test/java/io/substrait/TestBase.java | 3 +- .../isthmus/cli/IsthmusEntryPoint.java | 4 +-- .../substrait/isthmus/SqlConverterBase.java | 3 +- .../io/substrait/isthmus/CalciteCallTest.java | 3 +- .../isthmus/RelCopyOnWriteVisitorTest.java | 4 ++- .../io/substrait/spark/SparkExtension.scala | 4 +-- 12 files changed, 49 insertions(+), 34 deletions(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index e05aa1f5b..c658fcbce 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.ImmutableExtensionLookup; @@ -23,7 +24,7 @@ public class ProtoExtendedExpressionConverter { new ExtensionCollector(), SimpleExtension.ExtensionCollection.builder().build()); public ProtoExtendedExpressionConverter() { - this(SimpleExtension.loadDefaults()); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 46b7a920c..39d28298f 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -1,5 +1,9 @@ package io.substrait.extension; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + public class DefaultExtensionCatalog { public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml"; public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml"; @@ -14,4 +18,28 @@ public class DefaultExtensionCatalog { public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml"; public static final String FUNCTIONS_SET = "/functions_set.yaml"; public static final String FUNCTIONS_STRING = "/functions_string.yaml"; + + public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = + loadDefaultCollection(); + + private static SimpleExtension.ExtensionCollection loadDefaultCollection() { + List defaultFiles = + Arrays.asList( + "boolean", + "aggregate_generic", + "aggregate_approx", + "arithmetic_decimal", + "arithmetic", + "comparison", + "datetime", + "logarithmic", + "rounding", + "rounding_decimal", + "string") + .stream() + .map(c -> String.format("/functions_%s.yaml", c)) + .collect(Collectors.toList()); + + return SimpleExtension.load(defaultFiles); + } } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index ad0548920..14e11ca10 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -22,7 +22,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; @@ -701,27 +700,6 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { } } - public static ExtensionCollection loadDefaults() { - List defaultFiles = - Arrays.asList( - "boolean", - "aggregate_generic", - "aggregate_approx", - "arithmetic_decimal", - "arithmetic", - "comparison", - "datetime", - "logarithmic", - "rounding", - "rounding_decimal", - "string") - .stream() - .map(c -> String.format("/functions_%s.yaml", c)) - .collect(Collectors.toList()); - - return load(defaultFiles); - } - public static ExtensionCollection load(List resourcePaths) { if (resourcePaths.isEmpty()) { throw new IllegalArgumentException("Require at least one resource path."); diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index b6d5c3bc2..15145fe5c 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -1,5 +1,6 @@ package io.substrait.plan; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.ImmutableExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -16,7 +17,7 @@ public class ProtoPlanConverter { protected final SimpleExtension.ExtensionCollection extensionCollection; public ProtoPlanConverter() { - this(SimpleExtension.loadDefaults()); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } public ProtoPlanConverter(SimpleExtension.ExtensionCollection extensionCollection) { diff --git a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java index 7c4d59df6..c17245fb0 100644 --- a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java @@ -5,6 +5,7 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; import io.substrait.type.proto.ProtoTypeConverter; @@ -24,7 +25,7 @@ public class ProtoAggregateFunctionConverter { public ProtoAggregateFunctionConverter( ExtensionLookup lookup, ProtoExpressionConverter protoExpressionConverter) { - this(lookup, SimpleExtension.loadDefaults(), protoExpressionConverter); + this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExpressionConverter); } public ProtoAggregateFunctionConverter( diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index a055a8f97..5ed4b5d4d 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.AdvancedExtension; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; import io.substrait.hint.Hint; @@ -55,7 +56,7 @@ public class ProtoRelConverter { private final ProtoTypeConverter protoTypeConverter; public ProtoRelConverter(ExtensionLookup lookup) { - this(lookup, SimpleExtension.loadDefaults()); + this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION); } public ProtoRelConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollection extensions) { diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 6482c0d03..6082db8e9 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.relation.ProtoRelConverter; @@ -13,7 +14,7 @@ public abstract class TestBase { protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection = - SimpleExtension.loadDefaults(); + DefaultExtensionCatalog.DEFAULT_COLLECTION; protected TypeCreator R = TypeCreator.REQUIRED; diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index 58a7e123f..35d5c6ed2 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -4,7 +4,7 @@ import com.google.protobuf.Message; import com.google.protobuf.TextFormat; import com.google.protobuf.util.JsonFormat; -import io.substrait.extension.SimpleExtension; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.isthmus.FeatureBoard; import io.substrait.isthmus.ImmutableFeatureBoard; import io.substrait.isthmus.SqlExpressionToSubstrait; @@ -87,7 +87,7 @@ public Integer call() throws Exception { // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpressions != null) { SqlExpressionToSubstrait converter = - new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); + new SqlExpressionToSubstrait(featureBoard, DefaultExtensionCatalog.DEFAULT_COLLECTION); ExtendedExpression extendedExpression = converter.convert(sqlExpressions, createStatements); printMessage(extendedExpression); } else { // by default Isthmus image are parsing SQL Query diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index e60df0b68..0f16faf6a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -20,7 +21,7 @@ class SqlConverterBase { protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - SimpleExtension.loadDefaults(); + DefaultExtensionCatalog.DEFAULT_COLLECTION; final RelDataTypeFactory factory; final RelOptCluster relOptCluster; diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java index 2f0f934a3..337fd5559 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java @@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableList; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; import io.substrait.isthmus.expression.ExpressionRexConverter; @@ -22,7 +23,7 @@ public class CalciteCallTest extends CalciteObjs { private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - SimpleExtension.loadDefaults(); + DefaultExtensionCatalog.DEFAULT_COLLECTION; private final ScalarFunctionConverter functionConverter = new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), type); private final RexExpressionConverter rexExpressionConverter = diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index 973f8829d..ceb01e9e3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -6,6 +6,7 @@ import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlDialect; import io.substrait.plan.Plan; @@ -249,7 +250,8 @@ private static class ReplaceCountDistinctWithApprox { private final ReplaceCountDistinctWithApproxVisitor visitor; public ReplaceCountDistinctWithApprox() { - visitor = new ReplaceCountDistinctWithApproxVisitor(SimpleExtension.loadDefaults()); + visitor = + new ReplaceCountDistinctWithApproxVisitor(DefaultExtensionCatalog.DEFAULT_COLLECTION); } public Optional modify(Plan plan) { diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index c470c7a42..7bb28d5ed 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -18,7 +18,7 @@ package io.substrait.spark import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction} -import io.substrait.extension.SimpleExtension +import io.substrait.extension.{DefaultExtensionCatalog, SimpleExtension} import java.util.Collections @@ -32,7 +32,7 @@ object SparkExtension { SimpleExtension.load(Collections.singletonList(uri)) private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = - SimpleExtension.loadDefaults() + DefaultExtensionCatalog.DEFAULT_COLLECTION; val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls)