Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.ImmutableExtensionLookup;
Expand All @@ -23,7 +24,7 @@ public class ProtoExtendedExpressionConverter {
new ExtensionCollector(), SimpleExtension.ExtensionCollection.builder().build());

public ProtoExtendedExpressionConverter() {
this(SimpleExtension.loadDefaults());
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package io.substrait.extension;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class DefaultExtensionCatalog {
public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml";
public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
Expand All @@ -14,4 +18,28 @@ public class DefaultExtensionCatalog {
public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml";
public static final String FUNCTIONS_SET = "/functions_set.yaml";
public static final String FUNCTIONS_STRING = "/functions_string.yaml";

public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
loadDefaultCollection();
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can load the default ExtensionCollection once, and then share it everywhere.


private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
List<String> defaultFiles =
Arrays.asList(
"boolean",
"aggregate_generic",
"aggregate_approx",
"arithmetic_decimal",
"arithmetic",
"comparison",
"datetime",
"logarithmic",
"rounding",
"rounding_decimal",
"string")
.stream()
.map(c -> String.format("/functions_%s.yaml", c))
.collect(Collectors.toList());

return SimpleExtension.load(defaultFiles);
}
}
22 changes: 0 additions & 22 deletions core/src/main/java/io/substrait/extension/SimpleExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -701,27 +700,6 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
}
}

public static ExtensionCollection loadDefaults() {
List<String> defaultFiles =
Arrays.asList(
"boolean",
"aggregate_generic",
"aggregate_approx",
"arithmetic_decimal",
"arithmetic",
"comparison",
"datetime",
"logarithmic",
"rounding",
"rounding_decimal",
"string")
.stream()
.map(c -> String.format("/functions_%s.yaml", c))
.collect(Collectors.toList());

return load(defaultFiles);
}

public static ExtensionCollection load(List<String> resourcePaths) {
if (resourcePaths.isEmpty()) {
throw new IllegalArgumentException("Require at least one resource path.");
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/java/io/substrait/plan/ProtoPlanConverter.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.plan;

import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.ImmutableExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand All @@ -16,7 +17,7 @@ public class ProtoPlanConverter {
protected final SimpleExtension.ExtensionCollection extensionCollection;

public ProtoPlanConverter() {
this(SimpleExtension.loadDefaults());
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

public ProtoPlanConverter(SimpleExtension.ExtensionCollection extensionCollection) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.type.proto.ProtoTypeConverter;
Expand All @@ -24,7 +25,7 @@ public class ProtoAggregateFunctionConverter {

public ProtoAggregateFunctionConverter(
ExtensionLookup lookup, ProtoExpressionConverter protoExpressionConverter) {
this(lookup, SimpleExtension.loadDefaults(), protoExpressionConverter);
this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExpressionConverter);
}

public ProtoAggregateFunctionConverter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.expression.Expression;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.AdvancedExtension;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.hint.Hint;
Expand Down Expand Up @@ -55,7 +56,7 @@ public class ProtoRelConverter {
private final ProtoTypeConverter protoTypeConverter;

public ProtoRelConverter(ExtensionLookup lookup) {
this(lookup, SimpleExtension.loadDefaults());
this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

public ProtoRelConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollection extensions) {
Expand Down
3 changes: 2 additions & 1 deletion core/src/test/java/io/substrait/TestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.ProtoRelConverter;
Expand All @@ -13,7 +14,7 @@
public abstract class TestBase {

protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection =
SimpleExtension.loadDefaults();
DefaultExtensionCatalog.DEFAULT_COLLECTION;

protected TypeCreator R = TypeCreator.REQUIRED;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import com.google.protobuf.Message;
import com.google.protobuf.TextFormat;
import com.google.protobuf.util.JsonFormat;
import io.substrait.extension.SimpleExtension;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.isthmus.FeatureBoard;
import io.substrait.isthmus.ImmutableFeatureBoard;
import io.substrait.isthmus.SqlExpressionToSubstrait;
Expand Down Expand Up @@ -87,7 +87,7 @@ public Integer call() throws Exception {
// Isthmus image is parsing SQL Expression if that argument is defined
if (sqlExpressions != null) {
SqlExpressionToSubstrait converter =
new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults());
new SqlExpressionToSubstrait(featureBoard, DefaultExtensionCatalog.DEFAULT_COLLECTION);
ExtendedExpression extendedExpression = converter.convert(sqlExpressions, createStatements);
printMessage(extendedExpression);
} else { // by default Isthmus image are parsing SQL Query
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.config.CalciteConnectionProperty;
Expand All @@ -20,7 +21,7 @@

class SqlConverterBase {
protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
SimpleExtension.loadDefaults();
DefaultExtensionCatalog.DEFAULT_COLLECTION;

final RelDataTypeFactory factory;
final RelOptCluster relOptCluster;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.google.common.collect.ImmutableList;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.SubstraitRelNodeConverter.Context;
import io.substrait.isthmus.expression.ExpressionRexConverter;
Expand All @@ -22,7 +23,7 @@
public class CalciteCallTest extends CalciteObjs {

private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
SimpleExtension.loadDefaults();
DefaultExtensionCatalog.DEFAULT_COLLECTION;
private final ScalarFunctionConverter functionConverter =
new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), type);
private final RexExpressionConverter rexExpressionConverter =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.sql.SubstraitSqlDialect;
import io.substrait.plan.Plan;
Expand Down Expand Up @@ -249,7 +250,8 @@ private static class ReplaceCountDistinctWithApprox {
private final ReplaceCountDistinctWithApproxVisitor visitor;

public ReplaceCountDistinctWithApprox() {
visitor = new ReplaceCountDistinctWithApproxVisitor(SimpleExtension.loadDefaults());
visitor =
new ReplaceCountDistinctWithApproxVisitor(DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

public Optional<Plan> modify(Plan plan) {
Expand Down
4 changes: 2 additions & 2 deletions spark/src/main/scala/io/substrait/spark/SparkExtension.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package io.substrait.spark

import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction}

import io.substrait.extension.SimpleExtension
import io.substrait.extension.{DefaultExtensionCatalog, SimpleExtension}

import java.util.Collections

Expand All @@ -32,7 +32,7 @@ object SparkExtension {
SimpleExtension.load(Collections.singletonList(uri))

private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
SimpleExtension.loadDefaults()
DefaultExtensionCatalog.DEFAULT_COLLECTION;

val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls)

Expand Down