From 57aba61d434d41f7f982cdc5ba9263ebe6817a9a Mon Sep 17 00:00:00 2001 From: Wesley Zheng Date: Tue, 5 May 2026 18:06:15 -0700 Subject: [PATCH 1/5] merging --- .gitignore | 25 + .mvn/wrapper/maven-wrapper.properties | 2 + README.md | 322 +--- flake.lock | 304 +-- flake.nix | 93 +- mvnw | 291 +++ mvnw.cmd | 163 ++ pom.xml | 110 ++ scripts/build-qed-prover.sh | 10 + scripts/generate-rule-json.sh | 45 + scripts/install-dependencies.sh | 17 + scripts/test-codegen.sh | 98 + scripts/test-rules.sh | 33 + .../Backends/Calcite/CalciteGenerator.java | 518 ++++++ .../qed/Backends/Calcite/CalciteTester.java | 180 ++ .../org/qed/Backends/Calcite/EmptyConfig.java | 33 + .../Generated/AggregateExtractProject.java | 41 + .../Generated/AggregateFilterTranspose.java | 44 + .../Generated/AggregateJoinJoinRemove.java | 41 + .../Generated/AggregateJoinRemove.java | 43 + .../AggregateProjectConstantToDummyJoin.java | 41 + .../Generated/AggregateProjectMerge.java | 41 + .../Generated/FilterAggregateTranspose.java | 44 + .../Calcite/Generated/FilterIntoJoin.java | 41 + .../Calcite/Generated/FilterMerge.java | 41 + .../Generated/FilterProjectTranspose.java | 41 + .../Calcite/Generated/FilterReduceFalse.java | 41 + .../Calcite/Generated/FilterReduceTrue.java | 41 + .../Generated/FilterSetOpTranspose.java | 41 + .../Calcite/Generated/IntersectMerge.java | 41 + .../Generated/JoinAddRedundantSemiJoin.java | 41 + .../Calcite/Generated/JoinCommute.java | 55 + .../Calcite/Generated/JoinConditionPush.java | 45 + .../Calcite/Generated/JoinExtractFilter.java | 41 + .../JoinPushTransitivePredicates.java | 41 + .../Calcite/Generated/JoinReduceFalse.java | 41 + .../Calcite/Generated/JoinReduceTrue.java | 41 + .../Calcite/Generated/MinusMerge.java | 41 + .../Generated/ProjectAggregateMerge.java | 41 + .../Generated/ProjectFilterTranspose.java | 41 + .../Calcite/Generated/ProjectMerge.java | 41 + .../Calcite/Generated/PruneEmptyFilter.java | 41 + .../Generated/PruneEmptyIntersect.java | 41 + 
.../Calcite/Generated/PruneEmptyMinus.java | 41 + .../Calcite/Generated/PruneEmptyProject.java | 41 + .../Calcite/Generated/PruneEmptyUnion.java | 41 + .../Generated/SemiJoinFilterTranspose.java | 41 + .../Calcite/Generated/UnionMerge.java | 41 + .../Generated/UnionPullUpConstants.java | 41 + .../Calcite/Generated/UnionToDistinct.java | 41 + .../qed/Backends/Calcite/HelperFunctions.java | 649 +++++++ .../Tests/AggregateExtractProjectTest.java | 58 + .../Tests/AggregateFilterTransposeTest.java | 59 + .../Tests/AggregateJoinJoinRemoveTest.java | 122 ++ .../Tests/AggregateJoinRemoveTest.java | 82 + ...gregateProjectConstantToDummyJoinTest.java | 73 + .../Tests/AggregateProjectMergeTest.java | 72 + .../Tests/FilterAggregateTransposeTest.java | 58 + .../Calcite/Tests/FilterIntoJoinTest.java | 41 + .../Calcite/Tests/FilterMergeTest.java | 36 + .../Tests/FilterProjectTransposeTest.java | 59 + .../Calcite/Tests/FilterReduceFalseTest.java | 46 + .../Calcite/Tests/FilterReduceTrueTest.java | 45 + .../Tests/FilterSetOpTransposeTest.java | 37 + .../Calcite/Tests/IntersectMergeTest.java | 36 + .../Tests/JoinAddRedundantSemiJoinTest.java | 73 + .../Calcite/Tests/JoinCommuteTest.java | 77 + .../Calcite/Tests/JoinConditionPushTest.java | 81 + .../Calcite/Tests/JoinExtractFilterTest.java | 46 + .../JoinPushTransitivePredicatesTest.java | 42 + .../Calcite/Tests/JoinReduceFalseTest.java | 42 + .../Calcite/Tests/JoinReduceTrueTest.java | 42 + .../Calcite/Tests/MinusMergeTest.java | 34 + .../Tests/ProjectAggregateMergeTest.java | 74 + .../Tests/ProjectFilterTransposeTest.java | 59 + .../Calcite/Tests/ProjectMergeTest.java | 60 + .../Calcite/Tests/PruneEmptyFilterTest.java | 46 + .../Tests/PruneEmptyIntersectTest.java | 58 + .../Calcite/Tests/PruneEmptyMinusTest.java | 51 + .../Calcite/Tests/PruneEmptyProjectTest.java | 41 + .../Calcite/Tests/PruneEmptyUnionTest.java | 62 + .../Tests/SemiJoinFilterTransposeTest.java | 48 + .../Calcite/Tests/UnionMergeTest.java | 36 + 
.../Tests/UnionPullUpConstantsTest.java | 87 + .../Calcite/Tests/UnionToDistinctTest.java | 58 + .../Cockroach/CockroachGenerator.java | 1657 +++++++++++++++++ .../Backends/Cockroach/CockroachTester.java | 187 ++ .../org/qed/Backends/Cockroach/CockroachTests | 1579 ++++++++++++++++ .../Generated/AggregateExtractProject.opt | 12 + .../Generated/AggregateFilterTranspose.opt | 18 + .../Generated/AggregateJoinJoinRemove.opt | 38 + .../Generated/AggregateJoinRemove.opt | 24 + .../AggregateProjectConstantToDummyJoin.opt | 12 + .../Generated/AggregateProjectMerge.opt | 16 + .../Generated/FilterAggregateTranspose.opt | 18 + .../Cockroach/Generated/FilterIntoJoin.opt | 17 + .../Cockroach/Generated/FilterMerge.opt | 13 + .../Generated/FilterProjectTranspose.opt | 19 + .../Cockroach/Generated/FilterReduceFalse.opt | 11 + .../Cockroach/Generated/FilterReduceTrue.opt | 7 + .../Generated/FilterSetOpTranspose.opt | 26 + .../Cockroach/Generated/IntersectMerge.opt | 20 + .../Generated/JoinAddRedundantSemiJoin.opt | 19 + .../Cockroach/Generated/JoinCommute.opt | 22 + .../Cockroach/Generated/JoinConditionPush.opt | 22 + .../Cockroach/Generated/JoinExtractFilter.opt | 15 + .../JoinPushTransitivePredicates.opt | 17 + .../Cockroach/Generated/JoinReduceFalse.opt | 21 + .../Cockroach/Generated/JoinReduceTrue.opt | 18 + .../Cockroach/Generated/MinusMerge.opt | 18 + .../Generated/ProjectAggregateMerge.opt | 27 + .../Generated/ProjectFilterTranspose.opt | 18 + .../Cockroach/Generated/ProjectMerge.opt | 23 + .../Cockroach/Generated/PruneEmptyFilter.opt | 7 + .../Generated/PruneEmptyIntersect.opt | 7 + .../Cockroach/Generated/PruneEmptyMinus.opt | 7 + .../Cockroach/Generated/PruneEmptyProject.opt | 8 + .../Cockroach/Generated/PruneEmptyUnion.opt | 8 + .../Generated/SemiJoinFilterTranspose.opt | 20 + .../Cockroach/Generated/UnionMerge.opt | 20 + .../Generated/UnionPullUpConstants.opt | 33 + .../Cockroach/Generated/UnionToDistinct.opt | 29 + .../qed/Backends/Cockroach/HelperFunctions.go | 916 
+++++++++ .../java/org/qed/Backends/Datafusion/.envrc | 0 .../org/qed/Backends/Datafusion/Cargo.lock | 0 .../org/qed/Backends/Datafusion/Cargo.toml | 0 .../org/qed/Backends/Datafusion/README.md | 294 +++ .../Backends/Datafusion/examples}/README.md | 0 .../examples}/export_rules_to_qed.rs | 0 .../Datafusion/examples}/optimizer.rs | 0 .../examples}/optimizer_repl/mod.rs | 0 .../examples}/optimizer_repl/tables.rs | 0 .../examples}/optimizer_repl/wrappers.rs | 0 .../Backends/Datafusion/examples}/tpch/mod.rs | 0 .../Datafusion/examples}/tpch/queries.rs | 0 .../Datafusion/examples}/tpch/schema.rs | 0 .../Datafusion/examples}/tpch_optimize.rs | 0 .../examples}/user_defined_left_semi_join.rs | 0 .../org/qed/Backends/Datafusion/flake.lock | 348 ++++ .../org/qed/Backends/Datafusion/flake.nix | 50 + .../qed/Backends/Datafusion/src}/ast/empty.rs | 0 .../Backends/Datafusion/src}/ast/extension.rs | 0 .../qed/Backends/Datafusion/src}/ast/mod.rs | 0 .../Backends/Datafusion/src}/ast/opaque.rs | 0 .../Backends/Datafusion/src}/ast/pattern.rs | 0 .../Datafusion/src}/ast/relational.rs | 0 .../Backends/Datafusion/src}/ast/source.rs | 0 .../org/qed/Backends/Datafusion/src}/lib.rs | 0 .../Datafusion/src}/matcher/default.rs | 0 .../Backends/Datafusion/src}/matcher/mod.rs | 0 .../Datafusion/src}/rule/impls/README.md | 0 .../rule/impls/filter_aggregate_transpose.rs | 0 .../src}/rule/impls/filter_into_join.rs | 0 .../src}/rule/impls/filter_merge.rs | 0 .../rule/impls/filter_project_transpose.rs | 0 .../src}/rule/impls/filter_reduce_false.rs | 0 .../src}/rule/impls/filter_reduce_true.rs | 0 .../src}/rule/impls/join_associate.rs | 0 .../src}/rule/impls/join_commute.rs | 0 .../src}/rule/impls/join_condition_push.rs | 0 .../src}/rule/impls/join_extract_filter.rs | 0 .../src}/rule/impls/join_project_transpose.rs | 0 .../Datafusion/src}/rule/impls/mod.rs | 0 .../src}/rule/impls/project_merge.rs | 0 .../src}/rule/impls/project_remove.rs | 0 .../src}/rule/impls/prune_empty_filter.rs | 0 
.../src}/rule/impls/prune_empty_project.rs | 0 .../src}/rule/impls/prune_empty_union.rs | 0 .../rule/impls/semi_join_filter_transpose.rs | 0 .../qed/Backends/Datafusion/src}/rule/mod.rs | 0 .../qed/Backends/Datafusion/src}/rule/test.rs | 0 .../Backends/Datafusion/src}/verifier/mod.rs | 0 .../Backends/Datafusion/src}/verifier/qed.rs | 0 .../Backends/MySQL/Generated/FilterMerge1.sql | 5 + .../Backends/MySQL/Generated/FilterMerge2.sql | 5 + .../Backends/MySQL/Generated/JoinCommute1.sql | 5 + .../Backends/MySQL/Generated/JoinCommute2.sql | 5 + .../MySQL/Generated/ProjectMerge1.sql | 5 + .../MySQL/Generated/ProjectMerge2.sql | 5 + .../qed/Backends/MySQL/MySQLGenerator.java | 155 ++ .../Backends/MySQL/Tests/FilterMerge1Test.sql | 2 + .../Backends/MySQL/Tests/FilterMerge2Test.sql | 2 + .../Backends/MySQL/Tests/JoinCommute1Test.sql | 3 + .../Backends/MySQL/Tests/JoinCommute2Test.sql | 3 + .../qed/Backends/MySQL/Tests/MySQLTester.java | 45 + .../MySQL/Tests/ProjectMerge1Test.sql | 1 + .../MySQL/Tests/ProjectMerge2Test.sql | 1 + .../qed/Backends/MySQL/Tests/script-mysql.py | 80 + .../ProxySQL/Generated/FilterMerge.sql | 6 + .../ProxySQL/Generated/FilterReduceFalse.sql | 6 + .../ProxySQL/Generated/FilterReduceTrue.sql | 6 + .../ProxySQL/Generated/JoinCommute.sql | 6 + .../ProxySQL/Generated/ProjectMerge.sql | 6 + .../Backends/ProxySQL/ProxySQLGenerator.java | 135 ++ .../ProxySQL/Tests/FilterMergeTest.sql | 1 + .../ProxySQL/Tests/FilterReduceFalseTest.sql | 1 + .../ProxySQL/Tests/FilterReduceTrueTest.sql | 1 + .../ProxySQL/Tests/JoinCommuteTest.sql | 1 + .../ProxySQL/Tests/ProjectMergeTest.sql | 1 + .../ProxySQL/Tests/script-proxysql.sh | 58 + src/main/java/org/qed/CodeGenerator.java | 273 +++ src/main/java/org/qed/Env.java | 39 + src/main/java/org/qed/JSONDeserializer.java | 365 ++++ src/main/java/org/qed/JSONSerializer.java | 228 +++ src/main/java/org/qed/Main.java | 150 ++ src/main/java/org/qed/ProjectPaths.java | 19 + src/main/java/org/qed/QedTable.java | 80 + 
src/main/java/org/qed/RRule.java | 115 ++ src/main/java/org/qed/RRuleInstance.java | 481 +++++ .../AggregateExtractProject.java | 23 + .../AggregateFilterTranspose.java | 22 + .../AggregateJoinJoinRemove.java | 40 + .../RRuleInstances/AggregateJoinRemove.java | 32 + .../AggregateProjectConstantToDummyJoin.java | 144 ++ .../RRuleInstances/AggregateProjectMerge.java | 23 + .../FilterAggregateTranspose.java | 22 + .../qed/RRuleInstances/FilterIntoJoin.java | 23 + .../org/qed/RRuleInstances/FilterMerge.java | 21 + .../FilterProjectTranspose.java | 20 + .../qed/RRuleInstances/FilterReduceFalse.java | 19 + .../qed/RRuleInstances/FilterReduceTrue.java | 19 + .../RRuleInstances/FilterSetOpTranspose.java | 23 + .../qed/RRuleInstances/IntersectMerge.java | 29 + .../JoinAddRedundantSemiJoin.java | 21 + .../org/qed/RRuleInstances/JoinCommute.java | 48 + .../qed/RRuleInstances/JoinConditionPush.java | 56 + .../qed/RRuleInstances/JoinExtractFilter.java | 22 + .../JoinPushTransitivePredicates.java | 23 + .../qed/RRuleInstances/JoinReduceFalse.java | 20 + .../qed/RRuleInstances/JoinReduceTrue.java | 21 + .../org/qed/RRuleInstances/MinusMerge.java | 20 + .../RRuleInstances/ProjectAggregateMerge.java | 121 ++ .../ProjectFilterTranspose.java | 20 + .../org/qed/RRuleInstances/ProjectMerge.java | 22 + .../qed/RRuleInstances/PruneEmptyFilter.java | 20 + .../RRuleInstances/PruneEmptyIntersect.java | 19 + .../qed/RRuleInstances/PruneEmptyMinus.java | 19 + .../qed/RRuleInstances/PruneEmptyProject.java | 20 + .../qed/RRuleInstances/PruneEmptyUnion.java | 19 + .../SemiJoinFilterTranspose.java | 24 + .../org/qed/RRuleInstances/UnionMerge.java | 20 + .../RRuleInstances/UnionPullUpConstants.java | 150 ++ .../qed/RRuleInstances/UnionToDistinct.java | 62 + src/main/java/org/qed/RawPlanner.java | 240 +++ src/main/java/org/qed/RelFolder.java | 26 + src/main/java/org/qed/RelJSONShuttle.java | 364 ++++ src/main/java/org/qed/RelPruner.java | 124 ++ src/main/java/org/qed/RelRN.java | 332 ++++ 
src/main/java/org/qed/RelRacketShuttle.java | 161 ++ src/main/java/org/qed/RelType.java | 55 + src/main/java/org/qed/RexRN.java | 157 ++ src/main/java/org/qed/RuleBuilder.java | 125 ++ src/main/java/org/qed/SExpr.java | 88 + src/main/java/org/qed/SQLJSONParser.java | 63 + src/main/java/org/qed/SchemaGenerator.java | 276 +++ .../JoinAssociate.java | 69 + .../PruneLeftEmptyJoin.java | 20 + .../PruneRightEmptyJoin.java | 22 + .../SemiJoinJoinTranspose.java | 25 + .../SemiJoinProjectTranspose.java | 23 + .../SemiJoinRemove.java | 21 + 261 files changed, 17017 insertions(+), 576 deletions(-) create mode 100644 .mvn/wrapper/maven-wrapper.properties create mode 100755 mvnw create mode 100755 mvnw.cmd create mode 100644 pom.xml create mode 100644 scripts/build-qed-prover.sh create mode 100644 scripts/generate-rule-json.sh create mode 100644 scripts/install-dependencies.sh create mode 100644 scripts/test-codegen.sh create mode 100644 scripts/test-rules.sh create mode 100644 src/main/java/org/qed/Backends/Calcite/CalciteGenerator.java create mode 100644 src/main/java/org/qed/Backends/Calcite/CalciteTester.java create mode 100644 src/main/java/org/qed/Backends/Calcite/EmptyConfig.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/AggregateExtractProject.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/AggregateFilterTranspose.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinJoinRemove.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinRemove.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectConstantToDummyJoin.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/FilterAggregateTranspose.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/FilterIntoJoin.java create mode 100644 
src/main/java/org/qed/Backends/Calcite/Generated/FilterMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/FilterProjectTranspose.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceFalse.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceTrue.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/FilterSetOpTranspose.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/IntersectMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinAddRedundantSemiJoin.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinCommute.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinConditionPush.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinExtractFilter.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinPushTransitivePredicates.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceFalse.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceTrue.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/MinusMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/ProjectAggregateMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/ProjectFilterTranspose.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/ProjectMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyFilter.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyIntersect.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyMinus.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyProject.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyUnion.java create mode 
100644 src/main/java/org/qed/Backends/Calcite/Generated/SemiJoinFilterTranspose.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/UnionMerge.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/UnionPullUpConstants.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Generated/UnionToDistinct.java create mode 100644 src/main/java/org/qed/Backends/Calcite/HelperFunctions.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/AggregateExtractProjectTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/AggregateFilterTransposeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinJoinRemoveTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinRemoveTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectConstantToDummyJoinTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterAggregateTransposeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterIntoJoinTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterProjectTransposeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceFalseTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceTrueTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/FilterSetOpTransposeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/IntersectMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/JoinAddRedundantSemiJoinTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/JoinCommuteTest.java create mode 100644 
src/main/java/org/qed/Backends/Calcite/Tests/JoinConditionPushTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/JoinExtractFilterTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/JoinPushTransitivePredicatesTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceFalseTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceTrueTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/MinusMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/ProjectAggregateMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/ProjectFilterTransposeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/ProjectMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyFilterTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyIntersectTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyMinusTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyProjectTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyUnionTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/SemiJoinFilterTransposeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/UnionMergeTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/UnionPullUpConstantsTest.java create mode 100644 src/main/java/org/qed/Backends/Calcite/Tests/UnionToDistinctTest.java create mode 100644 src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java create mode 100644 src/main/java/org/qed/Backends/Cockroach/CockroachTester.java create mode 100644 src/main/java/org/qed/Backends/Cockroach/CockroachTests create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt create mode 100644 
src/main/java/org/qed/Backends/Cockroach/Generated/AggregateFilterTranspose.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectConstantToDummyJoin.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterAggregateTranspose.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterIntoJoin.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterProjectTranspose.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceTrue.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterSetOpTranspose.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/IntersectMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinAddRedundantSemiJoin.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinCommute.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinPushTransitivePredicates.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt create mode 100644 
src/main/java/org/qed/Backends/Cockroach/Generated/ProjectAggregateMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/ProjectFilterTranspose.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/ProjectMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/SemiJoinFilterTranspose.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/UnionMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/UnionPullUpConstants.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/UnionToDistinct.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/HelperFunctions.go rename .envrc => src/main/java/org/qed/Backends/Datafusion/.envrc (100%) rename Cargo.lock => src/main/java/org/qed/Backends/Datafusion/Cargo.lock (100%) rename Cargo.toml => src/main/java/org/qed/Backends/Datafusion/Cargo.toml (100%) create mode 100644 src/main/java/org/qed/Backends/Datafusion/README.md rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/README.md (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/export_rules_to_qed.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/optimizer.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/optimizer_repl/mod.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/optimizer_repl/tables.rs (100%) rename {examples => 
src/main/java/org/qed/Backends/Datafusion/examples}/optimizer_repl/wrappers.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/tpch/mod.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/tpch/queries.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/tpch/schema.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/tpch_optimize.rs (100%) rename {examples => src/main/java/org/qed/Backends/Datafusion/examples}/user_defined_left_semi_join.rs (100%) create mode 100644 src/main/java/org/qed/Backends/Datafusion/flake.lock create mode 100644 src/main/java/org/qed/Backends/Datafusion/flake.nix rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/empty.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/extension.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/mod.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/opaque.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/pattern.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/relational.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/ast/source.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/lib.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/matcher/default.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/matcher/mod.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/README.md (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/filter_aggregate_transpose.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/filter_into_join.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/filter_merge.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/filter_project_transpose.rs 
(100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/filter_reduce_false.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/filter_reduce_true.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/join_associate.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/join_commute.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/join_condition_push.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/join_extract_filter.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/join_project_transpose.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/mod.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/project_merge.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/project_remove.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/prune_empty_filter.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/prune_empty_project.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/prune_empty_union.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/impls/semi_join_filter_transpose.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/mod.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/rule/test.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/verifier/mod.rs (100%) rename src/{ => main/java/org/qed/Backends/Datafusion/src}/verifier/qed.rs (100%) create mode 100644 src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge1.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge2.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute1.sql create mode 100644 
src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute2.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge1.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge2.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/MySQLGenerator.java create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge1Test.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge2Test.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute1Test.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute2Test.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/MySQLTester.java create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge1Test.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge2Test.sql create mode 100644 src/main/java/org/qed/Backends/MySQL/Tests/script-mysql.py create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Generated/FilterMerge.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceFalse.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceTrue.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Generated/JoinCommute.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Generated/ProjectMerge.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/ProxySQLGenerator.java create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Tests/FilterMergeTest.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceFalseTest.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceTrueTest.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Tests/JoinCommuteTest.sql create mode 100644 src/main/java/org/qed/Backends/ProxySQL/Tests/ProjectMergeTest.sql create mode 100644 
src/main/java/org/qed/Backends/ProxySQL/Tests/script-proxysql.sh create mode 100644 src/main/java/org/qed/CodeGenerator.java create mode 100644 src/main/java/org/qed/Env.java create mode 100644 src/main/java/org/qed/JSONDeserializer.java create mode 100644 src/main/java/org/qed/JSONSerializer.java create mode 100644 src/main/java/org/qed/Main.java create mode 100644 src/main/java/org/qed/ProjectPaths.java create mode 100644 src/main/java/org/qed/QedTable.java create mode 100644 src/main/java/org/qed/RRule.java create mode 100644 src/main/java/org/qed/RRuleInstance.java create mode 100644 src/main/java/org/qed/RRuleInstances/AggregateExtractProject.java create mode 100644 src/main/java/org/qed/RRuleInstances/AggregateFilterTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/AggregateJoinJoinRemove.java create mode 100644 src/main/java/org/qed/RRuleInstances/AggregateJoinRemove.java create mode 100644 src/main/java/org/qed/RRuleInstances/AggregateProjectConstantToDummyJoin.java create mode 100644 src/main/java/org/qed/RRuleInstances/AggregateProjectMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterAggregateTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/IntersectMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinCommute.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinConditionPush.java create mode 100644 
src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinPushTransitivePredicates.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinReduceFalse.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinReduceTrue.java create mode 100644 src/main/java/org/qed/RRuleInstances/MinusMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/ProjectAggregateMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/ProjectFilterTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/ProjectMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/PruneEmptyFilter.java create mode 100644 src/main/java/org/qed/RRuleInstances/PruneEmptyIntersect.java create mode 100644 src/main/java/org/qed/RRuleInstances/PruneEmptyMinus.java create mode 100644 src/main/java/org/qed/RRuleInstances/PruneEmptyProject.java create mode 100644 src/main/java/org/qed/RRuleInstances/PruneEmptyUnion.java create mode 100644 src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/UnionMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/UnionPullUpConstants.java create mode 100644 src/main/java/org/qed/RRuleInstances/UnionToDistinct.java create mode 100644 src/main/java/org/qed/RawPlanner.java create mode 100644 src/main/java/org/qed/RelFolder.java create mode 100644 src/main/java/org/qed/RelJSONShuttle.java create mode 100644 src/main/java/org/qed/RelPruner.java create mode 100644 src/main/java/org/qed/RelRN.java create mode 100644 src/main/java/org/qed/RelRacketShuttle.java create mode 100644 src/main/java/org/qed/RelType.java create mode 100644 src/main/java/org/qed/RexRN.java create mode 100644 src/main/java/org/qed/RuleBuilder.java create mode 100644 src/main/java/org/qed/SExpr.java create mode 100644 src/main/java/org/qed/SQLJSONParser.java create mode 100644 
src/main/java/org/qed/SchemaGenerator.java create mode 100644 src/main/java/org/qed/UnprovableRRuleInstances/JoinAssociate.java create mode 100644 src/main/java/org/qed/UnprovableRRuleInstances/PruneLeftEmptyJoin.java create mode 100644 src/main/java/org/qed/UnprovableRRuleInstances/PruneRightEmptyJoin.java create mode 100644 src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinJoinTranspose.java create mode 100644 src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinProjectTranspose.java create mode 100644 src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinRemove.java diff --git a/.gitignore b/.gitignore index 7760884..eb0115b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,32 @@ # Rust !Cargo.toml !Cargo.lock +!pom.xml +!mvnw +!mvnw.cmd +!.mvn +!.mvn/** +.mvn/wrapper/maven-wrapper.jar +!scripts +!scripts/** !examples !examples/** !src !src/** + +# Java parser (RuleScript / Qed) +!parser +!parser/** + +# Parser: build & IDE (after !parser/**) +parser/target/ +parser/.idea/ +parser/*.iml +parser/.mvn/wrapper/maven-wrapper.jar +parser/.vscode +parser/.devcontainer +parser/.direnv/ +parser/.envrc + +# OS noise under tracked trees +**/.DS_Store diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties new file mode 100644 index 0000000..d853aaf --- /dev/null +++ b/.mvn/wrapper/maven-wrapper.properties @@ -0,0 +1,2 @@ +distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.11/apache-maven-3.9.11-bin.zip +wrapperUrl=https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar diff --git a/README.md b/README.md index 0462472..685510b 100644 --- a/README.md +++ b/README.md @@ -1,294 +1,122 @@ # RuleScript -A Rust DSL for building database query rewrite rules with uninterpreted symbols. RuleScript provides a minimal, pragmatic API that wraps DataFusion's native query planning while enabling rule verification and code generation. 
- -## What It Does - -RuleScript lets you express query optimizer rewrite rules using abstract patterns with uninterpreted symbols: - -```rust -// Pattern: source.filter(P).filter(Q) → source.filter(P AND Q) -crate::rule! { - FilterMergeRule { - schemas: { - source: (col: T), - }, - functions: { - P(T) -> Bool, - Q(T) -> Bool, - }, - from: { - let inner = crate::filter!(source, P(col)); - crate::filter!(inner, Q(col)) - }, - to: crate::filter!(source, P(col) && Q(col)), - } -} -``` - -The `P` and `Q` are uninterpreted predicates - they can represent ANY boolean expression. This means one rule definition covers infinite concrete cases. - -## Current State - -### Implemented Rules (22 total) - -**Filter Rules:** -- FilterMergeRule - Merge consecutive filters -- FilterProjectTransposeRule - Push filter below projection -- FilterAggregateTransposeRule - Push filter predicates on GROUP BY columns below aggregate -- FilterIntoJoinRule - Merge filter into join condition -- FilterReduceTrueRule - Remove filter with true predicate -- FilterReduceFalseRule - Replace filter with false predicate with empty relation - -**Project Rules:** -- ProjectMergeRule - Merge consecutive projections -- ProjectRemoveRule - Remove identity projections - -**Join Rules:** -- JoinCommuteRule - Swap join inputs -- JoinLeftConditionPushRule - Push left-table predicates down as filter on left input -- JoinRightConditionPushRule - Push right-table predicates down as filter on right input -- JoinExtractFilterRule - Extract join condition as filter above join -- JoinLeftProjectTransposeRule - Pull projection from left join input up -- JoinRightProjectTransposeRule - Pull projection from right join input up -- JoinAssociateRule - Restructure nested joins using associativity +RuleScript is an engine-agnostic domain-specific language (DSL) for developing query rewrite rules. +For details, please see our [paper](http://www2.eecs.berkeley.edu/Pubs/TechRpts/2024/EECS-2024-140.pdf). 
-**Semi-Join Rules:** -- LeftSemiJoinFilterTransposeRule - Pull filter above left semi-join -- RightSemiJoinFilterTransposeRule - Pull filter above right semi-join +## Build -**Prune Empty Rules:** -- PruneEmptyFilterRule - Remove filter over empty relation -- PruneEmptyProjectRule - Remove projection over empty relation -- PruneEmptyUnionLeftRule - Simplify union with empty left input -- PruneEmptyUnionRightRule - Simplify union with empty right input -- PruneEmptyUnionBothRule - Simplify union with both inputs empty - -All rules have comprehensive tests (79 unit tests + 21 doc tests, all passing). - -See `src/rule/impls/README.md` for detailed rule documentation. - -### Core Features - -- **Pattern Matching**: Full support for Filter, Project, Join, and user-defined operators -- **Predicate Decomposition**: Automatic splitting of conjunctive predicates based on column dependencies -- **Function Composition**: Support for nested function applications (e.g., `f(g(x))`) -- **Alias Handling**: Transparent matching through alias wrappers -- **Column Abstraction**: Smart column pattern matching that works with field partitions -- **User-Defined Operators**: Extensibility for custom logical operators with pattern matching and QED verification support -- **QED Export**: Serialization to QED format for rule verification, including EXISTS subqueries with outer column references -- **DataFusion Integration**: RuleWrapper adapter for seamless optimizer integration - -### Architecture - -``` -src/ - ast/ - opaque.rs - Abstract types, fields, schemas - relational.rs - Logical plan patterns (Source, Filter, Project, Join) - pattern.rs - Pattern functions (ScalarPattern, AggregatePattern) - extension.rs - User-defined operator support (UserDefinedLogicalOperator trait) - source.rs - Source node implementation - matcher/ - mod.rs - PatternMatcher trait and error types - default.rs - DefaultMatcher with full pattern matching logic - rule/ - mod.rs - Rule traits (RewriteRule, 
ApplicableRule) - test.rs - Test utilities (table helpers) - impls/ - Concrete rule implementations - verifier/ - mod.rs - Verifier trait for rule verification - qed.rs - QED format serialization with subquery support - lib.rs - Public API exports -examples/ - optimizer_repl/ - Interactive demo with all rules - user_defined_left_semi_join.rs - Example of user-defined operator with EXISTS semantics -``` +The project targets **Java 25** ([OpenJDK](https://openjdk.org/) / Temurin builds). Build with Maven: -## Quick Start - -### Define a Rule - -```rust -crate::rule! { - MyRule { - schemas: { - source: (x: T), - }, - functions: { - P(T) -> Bool, - }, - from: crate::filter!(source, P(x)), - to: source, // Remove the filter - } -} +```sh +./mvnw compile -q ``` -### Apply a Rule +## Generate Rules -```rust -use rulescript::rule::{ApplicableRule, impls::FilterMergeRule}; +Rules are generated per backend by running the corresponding tester. First build a classpath: -let rule = FilterMergeRule; -let optimized_plan = rule.try_apply(&concrete_plan)?; +```sh +./mvnw dependency:build-classpath -q -DincludeTypes=jar -Dmdep.outputFile=/tmp/cp.txt ``` -### Integrate with DataFusion +Then run the tester for the target backend: -```rust -use rulescript::rule::RuleWrapper; -use datafusion::optimizer::Optimizer; +```sh +# CockroachDB +java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.Cockroach.CockroachTester -let optimizer_rule = RuleWrapper::new(FilterMergeRule); -optimizer.add_rule(Arc::new(optimizer_rule)); -``` +# Apache Calcite +java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.Calcite.CalciteTester -## Run Interactive Demo - -```bash -cargo run --example optimizer_repl +# MySQL +java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.MySQL.Tests.MySQLTester ``` -Interactive demonstration with optimization rules: -- Choose which rules to apply -- See before/after query plans -- Real SQL parsing with DataFusion - -See `examples/README.md` for 
detailed usage. +Generated rule files are written to each backend's `Generated/` directory. -## Run Tests +## Adding Rules -```bash -# All tests -cargo test +Rules are defined in `src/main/java/org/qed/RRuleInstances/` as Java records implementing `RRule`. Each rule provides a `before()` pattern and an `after()` transformation in terms of RuleScript's relational algebra operators. The generators pick up every file in that directory automatically. -# Specific rule -cargo test filter_merge +**Example: `FilterMerge`** -# With output -cargo test -- --nocapture +```java +// src/main/java/org/qed/RRuleInstances/FilterMerge.java +public record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); -# Clippy checks -cargo clippy --all-targets -``` + @Override + public RelRN before() { + return source.filter(inner).filter(outer); // source.filter(P).filter(Q) + } -## Macro Reference - -### rule! - Define Complete Rules - -```rust -crate::rule! 
{ - RuleName { - schemas: { - input_name: (field: Type), - other_input: (x: T1, y: T2), - }, - functions: { - FuncName(InputType) -> OutputType, - Predicate(T1, T2) -> Bool, - }, - from: { /* pattern to match */ }, - to: { /* replacement pattern */ }, + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); // source.filter(P AND Q) } } ``` -### Plan Construction Macros +Running the generators will produce: +- `src/main/java/org/qed/Backends/Calcite/Generated/FilterMerge.java` — the Apache Calcite rule implementation +- `src/main/java/org/qed/Backends/Cockroach/Generated/FilterMerge.opt` — the CockroachDB optgen rule -```rust -// Filter -crate::filter!(source, predicate) +To also add a Calcite test, create `src/main/java/org/qed/Backends/Calcite/Tests/FilterMergeTest.java` with a `public static void runTest()` method that constructs `before` and `after` plans using `RuleBuilder` and calls `tester.verify(runner, before, after)`. The CalciteTester discovers and runs all `*Test.java` files in that directory automatically. -// Project -crate::project!(source, [expr1, expr2]) -crate::project!(source, [expr as alias]) - -// Join -crate::join!(left, right, Inner, condition) -crate::join!(left, right, Left, condition) -``` +For a full description of the rule language and available operators, see the [paper](http://www2.eecs.berkeley.edu/Pubs/TechRpts/2024/EECS-2024-140.pdf). -## Theoretical Foundation +## Apache DataFusion Backend -Based on the paper "RuleScript: A DSL for Query Optimizer Rules" which addresses the challenge of correctly implementing hundreds of rewrite rules in modern optimizers. +A separate Rust implementation targeting Apache DataFusion is available [here](https://github.com/qed-solver/rulescript). 
-**Key Concepts:** -- **Uninterpreted Symbols**: Abstract types/functions represent families of concrete queries -- **Pattern Matching**: Declarative patterns with automatic instantiation -- **Verification**: Rules can be verified via QED solver (future integration) +## Test Cases -**Why This Approach:** +### Apache Calcite -Modern query optimizers suffer from: -1. Error-prone manual implementation (100-200+ rules per optimizer) -2. Difficult to verify correctness -3. Redundant code across similar rules +Individual rule tests live in `src/main/java/org/qed/Backends/Calcite/Tests/`. All tests are run automatically when the Calcite tester is invoked: -RuleScript solves this by: -1. One rule definition → many concrete applications -2. Automated verification possible -3. Declarative patterns reduce implementation complexity - -## Key Design Decisions - -**Pattern Matching:** -- Column patterns match based on field partitions (one pattern can match multiple columns) -- Predicates decompose automatically based on column dependencies -- Function composition works through context mapping - -**Implementation:** -- Minimal abstraction over DataFusion's native types -- No async/tokio in tests (fast, simple tests) -- Smart defaults (all types map to Binary for uniformity) -- Functions are UDFs that error on execution (pattern-only) - -**Rule Application:** -- DefaultMatcher manages three binding types: fields, functions, sources -- Context-preserving instantiation (bindings stay in their context) -- Recursive plan transformation with captured bindings - -## External Resources - -The project references two external directories not tracked in git: - -- `parser/` - Java implementation with QED-verified rules -- `calcite/` - Apache Calcite source for rule reference +```sh +java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.Calcite.CalciteTester +``` -See `PARSER_AND_CALCITE_NOTES.txt` for details on these directories. 
+### CockroachDB -## Future Work +The generated `.opt` rule files live in `src/main/java/org/qed/Backends/Cockroach/Generated/`. -**Near-term:** -- More complex join rules (with 4-predicate decomposition) -- Rule families with meta-variables -- Additional user-defined operator examples +To run them against CockroachDB: -**Long-term:** -- SMT solver integration for automated verification -- Code generation adapters for different engines -- Performance optimizations for pattern matching +1. Clone the [CockroachDB repository](https://github.com/cockroachdb/cockroach) and check out commit `4b80cd59c6299f26b2b4f02a96064d5127ccad94` — this is the exact state of the codebase the rules were developed against. -## Known Limitations +2. Copy the generated rule files and test data into the CockroachDB source tree: + - Rule files → `pkg/sql/opt/norm/rules/` + - Test data → `pkg/sql/opt/norm/testdata/rules/CockroachTests` -**Current Implementation:** -- Join rules only support INNER joins (OUTER joins require IS NOT NULL predicates) -- Some complex expression types not yet handled (SIMILAR TO with escape) -- No optimization for pattern matching efficiency +3. Check your environment is set up correctly: + ```sh + ./dev doctor + ``` -**Design Constraints:** -- All abstract symbols must bind to at least one match (no optional predicates) -- Single binding per symbol per rule application -- Strict column validation (all references must exist in context) +4. Build CockroachDB: + ```sh + ./dev build + ``` -## Dependencies +5. Run the CockroachDB tests: + ```sh + ./dev test pkg/sql/opt/norm -f=TestNormRules/CockroachTests -v + ``` -- datafusion - Query planning framework -- thiserror - Error handling macros +## License -## Status +Copyright 2026 The Qed Team -Active development. Core pattern matching complete. User-defined operator support implemented. QED export working with subquery support. API stabilizing. 
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this project except in compliance with +the License. You may obtain a copy of the License at -**Test Status**: 79 unit tests + 21 doc tests passing + http://www.apache.org/licenses/LICENSE-2.0 -The project emphasizes correctness and extensibility. Pattern matching, instantiation, and user-defined operators are fully implemented with 22 working rules demonstrating the approach works with real DataFusion plans. QED serialization enables rule verification including complex cases like EXISTS subqueries with outer column references. +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an " +AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific +language governing permissions and limitations under the License. diff --git a/flake.lock b/flake.lock index fbbf408..b8e3f20 100644 --- a/flake.lock +++ b/flake.lock @@ -1,129 +1,18 @@ { "nodes": { - "cachix": { - "inputs": { - "devenv": [ - "devenv" - ], - "flake-compat": [ - "devenv", - "flake-compat" - ], - "git-hooks": [ - "devenv", - "git-hooks" - ], - "nixpkgs": [ - "devenv", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1760971495, - "narHash": "sha256-IwnNtbNVrlZIHh7h4Wz6VP0Furxg9Hh0ycighvL5cZc=", - "owner": "cachix", - "repo": "cachix", - "rev": "c5bfd933d1033672f51a863c47303fc0e093c2d2", - "type": "github" - }, - "original": { - "owner": "cachix", - "ref": "latest", - "repo": "cachix", - "type": "github" - } - }, - "devenv": { - "inputs": { - "cachix": "cachix", - "flake-compat": "flake-compat", - "flake-parts": "flake-parts", - "git-hooks": "git-hooks", - "nix": "nix", - "nixd": "nixd", - "nixpkgs": [ - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1767733209, - "narHash": "sha256-V1YN5JM1+/+MaiBH5puIjkjPssV8QNyFRT8EmCTurDY=", - "owner": "cachix", - "repo": "devenv", - "rev": 
"32a795ac142f4578aa5f6ecc8eafb79d253d99ae", - "type": "github" - }, - "original": { - "owner": "cachix", - "repo": "devenv", - "type": "github" - } - }, - "flake-compat": { + "cvc5-src": { "flake": false, "locked": { - "lastModified": 1761588595, - "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "lastModified": 1708192045, + "narHash": "sha256-hLb5qvWhAKeaGpR1HMH3URVXLuXjuvqapjx+loCa7Nc=", + "owner": "cvc5", + "repo": "cvc5", + "rev": "80878ee024f58a9a883479eed5dfe06402109e94", "type": "github" }, "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "flake": false, - "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-parts": { - "inputs": { - "nixpkgs-lib": [ - "devenv", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1760948891, - "narHash": "sha256-TmWcdiUUaWk8J4lpjzu4gCGxWY6/Ok7mOK4fIFfBuU4=", - "owner": "hercules-ci", - "repo": "flake-parts", - "rev": "864599284fc7c0ba6357ed89ed5e2cd5040f0c04", - "type": "github" - }, - "original": { - "owner": "hercules-ci", - "repo": "flake-parts", - "type": "github" - } - }, - "flake-root": { - "locked": { - "lastModified": 1723604017, - "narHash": "sha256-rBtQ8gg+Dn4Sx/s+pvjdq3CB2wQNzx9XGFq/JVGCB6k=", - "owner": "srid", - "repo": "flake-root", - "rev": "b759a56851e10cb13f6b8e5698af7b59c44be26e", - "type": "github" - }, - "original": { - "owner": "srid", - "repo": "flake-root", + "owner": "cvc5", + "repo": "cvc5", "type": "github" } }, @@ -132,11 +21,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1731533236, - "narHash": 
"sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", "owner": "numtide", "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", "type": "github" }, "original": { @@ -145,164 +34,27 @@ "type": "github" } }, - "git-hooks": { - "inputs": { - "flake-compat": [ - "devenv", - "flake-compat" - ], - "gitignore": "gitignore", - "nixpkgs": [ - "devenv", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1760663237, - "narHash": "sha256-BflA6U4AM1bzuRMR8QqzPXqh8sWVCNDzOdsxXEguJIc=", - "owner": "cachix", - "repo": "git-hooks.nix", - "rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37", - "type": "github" - }, - "original": { - "owner": "cachix", - "repo": "git-hooks.nix", - "type": "github" - } - }, - "gitignore": { - "inputs": { - "nixpkgs": [ - "devenv", - "git-hooks", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1709087332, - "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", - "owner": "hercules-ci", - "repo": "gitignore.nix", - "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", - "type": "github" - }, - "original": { - "owner": "hercules-ci", - "repo": "gitignore.nix", - "type": "github" - } - }, - "nix": { - "inputs": { - "flake-compat": [ - "devenv", - "flake-compat" - ], - "flake-parts": [ - "devenv", - "flake-parts" - ], - "git-hooks-nix": [ - "devenv", - "git-hooks" - ], - "nixpkgs": [ - "devenv", - "nixpkgs" - ], - "nixpkgs-23-11": [ - "devenv" - ], - "nixpkgs-regression": [ - "devenv" - ] - }, - "locked": { - "lastModified": 1766922625, - "narHash": "sha256-O0wExzdYqSNqbPYCQhUWeoKlDa7q6wxhuWiHolxqdl8=", - "owner": "cachix", - "repo": "nix", - "rev": "c62c4bdb6673871ae5cdc51c498df6292d5169aa", - "type": "github" - }, - "original": { - "owner": "cachix", - "ref": "devenv-2.32", - "repo": "nix", - "type": "github" - } - }, - "nixd": { - "inputs": { - "flake-parts": [ - 
"devenv", - "flake-parts" - ], - "flake-root": "flake-root", - "nixpkgs": [ - "devenv", - "nixpkgs" - ], - "treefmt-nix": "treefmt-nix" - }, - "locked": { - "lastModified": 1763964548, - "narHash": "sha256-JTRoaEWvPsVIMFJWeS4G2isPo15wqXY/otsiHPN0zww=", - "owner": "nix-community", - "repo": "nixd", - "rev": "d4bf15e56540422e2acc7bc26b20b0a0934e3f5e", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "nixd", - "type": "github" - } - }, "nixpkgs": { "locked": { - "lastModified": 1767640445, - "narHash": "sha256-UWYqmD7JFBEDBHWYcqE6s6c77pWdcU/i+bwD6XxMb8A=", + "lastModified": 1708247094, + "narHash": "sha256-H2VS7VwesetGDtIaaz4AMsRkPoSLEVzL/Ika8gnbUnE=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "9f0c42f8bc7151b8e7e5840fb3bd454ad850d8c5", + "rev": "045b51a3ae66f673ed44b5bbd1f4a341d96703bf", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-unstable", + "ref": "nixpkgs-unstable", "repo": "nixpkgs", "type": "github" } }, - "nixpkgs-python": { - "inputs": { - "flake-compat": "flake-compat_2", - "nixpkgs": [ - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1765052656, - "narHash": "sha256-DrMjrjxMttbGDoVxr/xke0ihd5GVd6fyUVsjuepEsCc=", - "owner": "cachix", - "repo": "nixpkgs-python", - "rev": "04b27dbad2e004cb237db202f21154eea3c4f89f", - "type": "github" - }, - "original": { - "owner": "cachix", - "repo": "nixpkgs-python", - "type": "github" - } - }, "root": { "inputs": { - "devenv": "devenv", + "cvc5-src": "cvc5-src", "flake-utils": "flake-utils", - "nixpkgs": "nixpkgs", - "nixpkgs-python": "nixpkgs-python" + "nixpkgs": "nixpkgs" } }, "systems": { @@ -319,28 +71,6 @@ "repo": "default", "type": "github" } - }, - "treefmt-nix": { - "inputs": { - "nixpkgs": [ - "devenv", - "nixd", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1734704479, - "narHash": "sha256-MMi74+WckoyEWBRcg/oaGRvXC9BVVxDZNRMpL+72wBI=", - "owner": "numtide", - "repo": "treefmt-nix", - "rev": "65712f5af67234dad91a5a4baee986a8b62dbf8f", - "type": 
"github" - }, - "original": { - "owner": "numtide", - "repo": "treefmt-nix", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 1a1c5b7..e7feadf 100644 --- a/flake.nix +++ b/flake.nix @@ -1,50 +1,59 @@ { inputs = { - devenv = { - inputs.nixpkgs.follows = "nixpkgs"; - url = "github:cachix/devenv"; - }; + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; flake-utils.url = "github:numtide/flake-utils"; - nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; - nixpkgs-python = { - inputs.nixpkgs.follows = "nixpkgs"; - url = "github:cachix/nixpkgs-python"; + cvc5-src = { + flake = false; + url = "github:cvc5/cvc5"; }; }; - outputs = inputs @ { - self, - devenv, - flake-utils, - nixpkgs, - ... - }: - flake-utils.lib.eachDefaultSystem (system: let - pkgs = import nixpkgs { - inherit system; - config.allowUnfree = true; - }; - in { - packages = { - devenv-up = self.devShells.${system}.default.config.procfileScript; - devenv-test = self.devShells.${system}.default.config.test; - }; + outputs = { self, nixpkgs, flake-utils, cvc5-src }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + cvc5-pname = "cvc5-1.0.5"; + cvc5-java = pkgs.stdenv.mkDerivation { + name = cvc5-pname; + src = cvc5-src; + + nativeBuildInputs = with pkgs; [ pkg-config cmake flex ]; + + buildInputs = with pkgs; [ + cadical.dev + symfpu + gmp + gtest + libantlr3c + antlr3_4 + boost + jdk21 + (python3.withPackages (ps: with ps; [ pyparsing toml ])) + ]; + + cmakeFlags = [ + "-DBUILD_BINDINGS_JAVA=ON" + "-DBUILD_SHARED_LIBS=1" + "-DCMAKE_BUILD_TYPE=Production" + "-DANTLR3_JAR=${pkgs.antlr3_4}/lib/antlr/antlr-3.4-complete.jar" + ]; + + preConfigure = '' + patchShebangs ./src/ + ''; - devShells.default = devenv.lib.mkShell { - inherit inputs pkgs; - modules = [ - { - git-hooks.hooks.alejandra.enable = true; - languages = { - nix.enable = true; - rust.enable = true; - }; - packages = with pkgs; [ - cvc5 - opencode - ]; - } - ]; - 
}; - }); + }; + in + { + devShells.default = pkgs.mkShell { + packages = with pkgs; [ + cvc5-java + jdk25_headless + maven + jetbrains.idea-community + ]; + CVC5_JAVA = "${cvc5-java}/share/java/cvc5.jar"; + LD_LIBRARY_PATH = pkgs.lib.strings.makeLibraryPath [ cvc5-java ]; + }; + }); } diff --git a/mvnw b/mvnw new file mode 100755 index 0000000..26794fe --- /dev/null +++ b/mvnw @@ -0,0 +1,291 @@ +#!/bin/sh +# ---------------------------------------------------------------------------- +# Maven Start Up Batch script +# +# Required ENV vars: +# ------------------ +# JAVA_HOME - location of a JDK home dir +# +# Optional ENV vars +# ----------------- +# M2_HOME - location of maven2's installed home dir +# MAVEN_OPTS - parameters passed to the Java VM when running Maven +# e.g. to debug Maven itself, use +# set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 +# MAVEN_SKIP_RC - flag to disable loading of mavenrc files +# ---------------------------------------------------------------------------- + +if [ -z "$MAVEN_SKIP_RC" ] ; then + + if [ -f /etc/mavenrc ] ; then + . /etc/mavenrc + fi + + if [ -f "$HOME/.mavenrc" ] ; then + . "$HOME/.mavenrc" + fi + +fi + +# OS specific support. $var _must_ be set to either true or false. 
+cygwin=false; +darwin=false; +mingw=false +case "`uname`" in + CYGWIN*) cygwin=true ;; + MINGW*) mingw=true;; + Darwin*) darwin=true + # Use /usr/libexec/java_home if available, otherwise fall back to /Library/Java/Home + # See https://developer.apple.com/library/mac/qa/qa1170/_index.html + if [ -z "$JAVA_HOME" ]; then + if [ -x "/usr/libexec/java_home" ]; then + export JAVA_HOME="`/usr/libexec/java_home`" + else + export JAVA_HOME="/Library/Java/Home" + fi + fi + ;; +esac + +if [ -z "$JAVA_HOME" ] ; then + if [ -r /etc/gentoo-release ] ; then + JAVA_HOME=`java-config --jre-home` + fi +fi + +if [ -z "$M2_HOME" ] ; then + ## resolve links - $0 may be a link to maven's home + PRG="$0" + + # need this for relative symlinks + while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG="`dirname "$PRG"`/$link" + fi + done + + saveddir=`pwd` + + M2_HOME=`dirname "$PRG"`/.. + + # make it fully qualified + M2_HOME=`cd "$M2_HOME" && pwd` + + cd "$saveddir" + # echo Using m2 at $M2_HOME +fi + +# For Cygwin, ensure paths are in UNIX format before anything is touched +if $cygwin ; then + [ -n "$M2_HOME" ] && + M2_HOME=`cygpath --unix "$M2_HOME"` + [ -n "$JAVA_HOME" ] && + JAVA_HOME=`cygpath --unix "$JAVA_HOME"` + [ -n "$CLASSPATH" ] && + CLASSPATH=`cygpath --path --unix "$CLASSPATH"` +fi + +# For Mingw, ensure paths are in UNIX format before anything is touched +if $mingw ; then + [ -n "$M2_HOME" ] && + M2_HOME="`(cd "$M2_HOME"; pwd)`" + [ -n "$JAVA_HOME" ] && + JAVA_HOME="`(cd "$JAVA_HOME"; pwd)`" +fi + +if [ -z "$JAVA_HOME" ]; then + javaExecutable="`which javac`" + if [ -n "$javaExecutable" ] && ! [ "`expr \"$javaExecutable\" : '\([^ ]*\)'`" = "no" ]; then + # readlink(1) is not available as standard on Solaris 10. + readLink=`which readlink` + if [ ! 
`expr "$readLink" : '\([^ ]*\)'` = "no" ]; then + if $darwin ; then + javaHome="`dirname \"$javaExecutable\"`" + javaExecutable="`cd \"$javaHome\" && pwd -P`/javac" + else + javaExecutable="`readlink -f \"$javaExecutable\"`" + fi + javaHome="`dirname \"$javaExecutable\"`" + javaHome=`expr "$javaHome" : '\(.*\)/bin'` + JAVA_HOME="$javaHome" + export JAVA_HOME + fi + fi +fi + +if [ -z "$JAVACMD" ] ; then + if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + else + JAVACMD="`which java`" + fi +fi + +if [ ! -x "$JAVACMD" ] ; then + echo "Error: JAVA_HOME is not defined correctly." >&2 + echo " We cannot execute $JAVACMD" >&2 + exit 1 +fi + +if [ -z "$JAVA_HOME" ] ; then + echo "Warning: JAVA_HOME environment variable is not set." +fi + +CLASSWORLDS_LAUNCHER=org.codehaus.plexus.classworlds.launcher.Launcher + +# traverses directory structure from process work directory to filesystem root +# first directory with .mvn subdirectory is considered project base directory +find_maven_basedir() { + + if [ -z "$1" ] + then + echo "Path not specified to find_maven_basedir" + return 1 + fi + + basedir="$1" + wdir="$1" + while [ "$wdir" != '/' ] ; do + if [ -d "$wdir"/.mvn ] ; then + basedir=$wdir + break + fi + # workaround for JBEAP-8937 (on Solaris 10/Sparc) + if [ -d "${wdir}" ]; then + wdir=`cd "$wdir/.."; pwd` + fi + # end of workaround + done + echo "${basedir}" +} + +# concatenates all lines of a file +concat_lines() { + if [ -f "$1" ]; then + echo "$(tr -s '\n' ' ' < "$1")" + fi +} + +BASE_DIR=`find_maven_basedir "$(pwd)"` +if [ -z "$BASE_DIR" ]; then + exit 1; +fi + +########################################################################################## +# Extension to allow automatically downloading the maven-wrapper.jar from Maven-central +# This allows using the maven wrapper in projects that prohibit 
checking in binary data. +########################################################################################## +if [ -r "$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" ]; then + if [ "$MVNW_VERBOSE" = true ]; then + echo "Found .mvn/wrapper/maven-wrapper.jar" + fi +else + if [ "$MVNW_VERBOSE" = true ]; then + echo "Couldn't find .mvn/wrapper/maven-wrapper.jar, downloading it ..." + fi + if [ -n "$MVNW_REPOURL" ]; then + jarUrl="$MVNW_REPOURL/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" + else + jarUrl="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" + fi + while IFS="=" read key value; do + case "$key" in (wrapperUrl) jarUrl="$value"; break ;; + esac + done < "$BASE_DIR/.mvn/wrapper/maven-wrapper.properties" + if [ "$MVNW_VERBOSE" = true ]; then + echo "Downloading from: $jarUrl" + fi + wrapperJarPath="$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" + if $cygwin; then + wrapperJarPath=`cygpath --path --windows "$wrapperJarPath"` + fi + + if command -v wget > /dev/null; then + if [ "$MVNW_VERBOSE" = true ]; then + echo "Found wget ... using wget" + fi + if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then + wget "$jarUrl" -O "$wrapperJarPath" + else + wget --http-user=$MVNW_USERNAME --http-password=$MVNW_PASSWORD "$jarUrl" -O "$wrapperJarPath" + fi + elif command -v curl > /dev/null; then + if [ "$MVNW_VERBOSE" = true ]; then + echo "Found curl ... using curl" + fi + if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then + curl -o "$wrapperJarPath" "$jarUrl" -f + else + curl --user $MVNW_USERNAME:$MVNW_PASSWORD -o "$wrapperJarPath" "$jarUrl" -f + fi + + else + if [ "$MVNW_VERBOSE" = true ]; then + echo "Falling back to using Java to download" + fi + javaClass="$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.java" + # For Cygwin, switch paths to Windows format before running javac + if $cygwin; then + javaClass=`cygpath --path --windows "$javaClass"` + fi + if [ -e "$javaClass" ]; then + if [ ! 
-e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then + if [ "$MVNW_VERBOSE" = true ]; then + echo " - Compiling MavenWrapperDownloader.java ..." + fi + # Compiling the Java class + ("$JAVA_HOME/bin/javac" "$javaClass") + fi + if [ -e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then + # Running the downloader + if [ "$MVNW_VERBOSE" = true ]; then + echo " - Running MavenWrapperDownloader.java ..." + fi + ("$JAVA_HOME/bin/java" -cp .mvn/wrapper MavenWrapperDownloader "$MAVEN_PROJECTBASEDIR") + fi + fi + fi +fi +########################################################################################## +# End of extension +########################################################################################## + +export MAVEN_PROJECTBASEDIR=${MAVEN_BASEDIR:-"$BASE_DIR"} +if [ "$MVNW_VERBOSE" = true ]; then + echo $MAVEN_PROJECTBASEDIR +fi +MAVEN_OPTS="$(concat_lines "$MAVEN_PROJECTBASEDIR/.mvn/jvm.config") $MAVEN_OPTS" + +# For Cygwin, switch paths to Windows format before running java +if $cygwin; then + [ -n "$M2_HOME" ] && + M2_HOME=`cygpath --path --windows "$M2_HOME"` + [ -n "$JAVA_HOME" ] && + JAVA_HOME=`cygpath --path --windows "$JAVA_HOME"` + [ -n "$CLASSPATH" ] && + CLASSPATH=`cygpath --path --windows "$CLASSPATH"` + [ -n "$MAVEN_PROJECTBASEDIR" ] && + MAVEN_PROJECTBASEDIR=`cygpath --path --windows "$MAVEN_PROJECTBASEDIR"` +fi + +# Provide a "standardized" way to retrieve the CLI args that will +# work with both Windows and non-Windows executions. 
+MAVEN_CMD_LINE_ARGS="$MAVEN_CONFIG $@" +export MAVEN_CMD_LINE_ARGS + +WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain + +exec "$JAVACMD" \ + $MAVEN_OPTS \ + -classpath "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" \ + "-Dmaven.home=${M2_HOME}" "-Dmaven.multiModuleProjectDirectory=${MAVEN_PROJECTBASEDIR}" \ + ${WRAPPER_LAUNCHER} $MAVEN_CONFIG "$@" diff --git a/mvnw.cmd b/mvnw.cmd new file mode 100755 index 0000000..8d5079b --- /dev/null +++ b/mvnw.cmd @@ -0,0 +1,163 @@ +@REM ---------------------------------------------------------------------------- +@REM Maven Start Up Batch script +@REM +@REM Required ENV vars: +@REM JAVA_HOME - location of a JDK home dir +@REM +@REM Optional ENV vars +@REM M2_HOME - location of maven2's installed home dir +@REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands +@REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a keystroke before ending +@REM MAVEN_OPTS - parameters passed to the Java VM when running Maven +@REM e.g. 
to debug Maven itself, use +@REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 +@REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files +@REM ---------------------------------------------------------------------------- + +@REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on' +@echo off +@REM set title of command window +title %0 +@REM enable echoing by setting MAVEN_BATCH_ECHO to 'on' +@if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO% + +@REM set %HOME% to equivalent of $HOME +if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%") + +@REM Execute a user defined script before this one +if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre +@REM check for pre script, once with legacy .bat ending and once with .cmd ending +if exist "%HOME%\mavenrc_pre.bat" call "%HOME%\mavenrc_pre.bat" +if exist "%HOME%\mavenrc_pre.cmd" call "%HOME%\mavenrc_pre.cmd" +:skipRcPre + +@setlocal + +set ERROR_CODE=0 + +@REM To isolate internal variables from possible post scripts, we use another setlocal +@setlocal + +@REM ==== START VALIDATION ==== +if not "%JAVA_HOME%" == "" goto OkJHome + +echo. +echo Error: JAVA_HOME not found in your environment. >&2 +echo Please set the JAVA_HOME variable in your environment to match the >&2 +echo location of your Java installation. >&2 +echo. +goto error + +:OkJHome +if exist "%JAVA_HOME%\bin\java.exe" goto init + +echo. +echo Error: JAVA_HOME is set to an invalid directory. >&2 +echo JAVA_HOME = "%JAVA_HOME%" >&2 +echo Please set the JAVA_HOME variable in your environment to match the >&2 +echo location of your Java installation. >&2 +echo. +goto error + +@REM ==== END VALIDATION ==== + +:init + +@REM Find the project base dir, i.e. the directory that contains the folder ".mvn". +@REM Fallback to current working directory if not found. 
+ +set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR% +IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir + +set EXEC_DIR=%CD% +set WDIR=%EXEC_DIR% +:findBaseDir +IF EXIST "%WDIR%"\.mvn goto baseDirFound +cd .. +IF "%WDIR%"=="%CD%" goto baseDirNotFound +set WDIR=%CD% +goto findBaseDir + +:baseDirFound +set MAVEN_PROJECTBASEDIR=%WDIR% +cd "%EXEC_DIR%" +goto endDetectBaseDir + +:baseDirNotFound +set MAVEN_PROJECTBASEDIR=%EXEC_DIR% +cd "%EXEC_DIR%" + +:endDetectBaseDir + +IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig + +@setlocal EnableExtensions EnableDelayedExpansion +for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! %%a +@endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS% + +:endReadAdditionalConfig + +SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" +set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" +set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain + +set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" + +FOR /F "tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( + IF "%%A"=="wrapperUrl" SET DOWNLOAD_URL=%%B +) + +@REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central +@REM This allows using the maven wrapper in projects that prohibit checking in binary data. +if exist %WRAPPER_JAR% ( + if "%MVNW_VERBOSE%" == "true" ( + echo Found %WRAPPER_JAR% + ) +) else ( + if not "%MVNW_REPOURL%" == "" ( + SET DOWNLOAD_URL="%MVNW_REPOURL%/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" + ) + if "%MVNW_VERBOSE%" == "true" ( + echo Couldn't find %WRAPPER_JAR%, downloading it ... 
+ echo Downloading from: %DOWNLOAD_URL% + ) + + powershell -Command "&{"^ + "$webclient = new-object System.Net.WebClient;"^ + "if (-not ([string]::IsNullOrEmpty('%MVNW_USERNAME%') -and [string]::IsNullOrEmpty('%MVNW_PASSWORD%'))) {"^ + "$webclient.Credentials = new-object System.Net.NetworkCredential('%MVNW_USERNAME%', '%MVNW_PASSWORD%');"^ + "}"^ + "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; $webclient.DownloadFile('%DOWNLOAD_URL%', '%WRAPPER_JAR%')"^ + "}" + if "%MVNW_VERBOSE%" == "true" ( + echo Finished downloading %WRAPPER_JAR% + ) +) +@REM End of extension + +@REM Provide a "standardized" way to retrieve the CLI args that will +@REM work with both Windows and non-Windows executions. +set MAVEN_CMD_LINE_ARGS=%* + +%MAVEN_JAVA_EXE% %JVM_CONFIG_MAVEN_PROPS% %MAVEN_OPTS% %MAVEN_DEBUG_OPTS% -classpath %WRAPPER_JAR% "-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" %WRAPPER_LAUNCHER% %MAVEN_CONFIG% %* +if ERRORLEVEL 1 goto error +goto end + +:error +set ERROR_CODE=1 + +:end +@endlocal & set ERROR_CODE=%ERROR_CODE% + +if not "%MAVEN_SKIP_RC%" == "" goto skipRcPost +@REM check for post script, once with legacy .bat ending and once with .cmd ending +if exist "%HOME%\mavenrc_post.bat" call "%HOME%\mavenrc_post.bat" +if exist "%HOME%\mavenrc_post.cmd" call "%HOME%\mavenrc_post.cmd" +:skipRcPost + +@REM pause the script if MAVEN_BATCH_PAUSE is set to 'on' +if "%MAVEN_BATCH_PAUSE%" == "on" pause + +if "%MAVEN_TERMINATE_CMD%" == "on" exit %ERROR_CODE% + +exit /B %ERROR_CODE% diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..115abe5 --- /dev/null +++ b/pom.xml @@ -0,0 +1,110 @@ + + + 4.0.0 + + org.qed + rulescript-java + 1.0-SNAPSHOT + jar + + + UTF-8 + 25 + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + + ${maven.compiler.release} + + + + org.codehaus.mojo + exec-maven-plugin + 3.5.0 + + ${project.basedir} + + + + calcite-codegen-test + + org.qed.Backends.Calcite.CalciteTester + + + + 
cockroach-codegen + + org.qed.Backends.Cockroach.CockroachTester + + + + qed-parser-main + + org.qed.Main + + + + + + + + + + org.apache.calcite + calcite-core + 1.35.0 + + + org.apache.calcite + calcite-server + 1.34.0 + + + org.apache.calcite.avatica + avatica-core + 1.23.0 + + + com.fasterxml.jackson.core + jackson-databind + 2.15.1 + + + commons-io + commons-io + 2.11.0 + + + org.slf4j + slf4j-api + 2.0.12 + + + org.slf4j + slf4j-simple + 2.0.12 + + + org.glavo.kala + kala-common + 0.67.0 + + + io.github.p-org.solvers + cvc5 + 0.0.7-v5 + + + com.mysql + mysql-connector-j + 8.0.33 + + + diff --git a/scripts/build-qed-prover.sh b/scripts/build-qed-prover.sh new file mode 100644 index 0000000..10fe4b0 --- /dev/null +++ b/scripts/build-qed-prover.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Build qed-prover with Rust nightly + +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly +source $HOME/.cargo/env + +git clone https://github.com/qed-solver/prover.git qed-prover +cd qed-prover +cargo +nightly build --release \ No newline at end of file diff --git a/scripts/generate-rule-json.sh b/scripts/generate-rule-json.sh new file mode 100644 index 0000000..172b0e5 --- /dev/null +++ b/scripts/generate-rule-json.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Script to generate JSON files for all RRule instances + +# Create temporary Java file for JSON generation +cat > JsonGenerator.java << 'EOF' +import org.qed.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.nio.file.*; + +public class JsonGenerator { + public static void main(String[] args) throws Exception { + String className = args[0]; + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + ObjectMapper mapper = new ObjectMapper(); + String fileName = rule.name() + "-" + rule.info() + ".json"; + ObjectNode jsonNode = rule.toJson(); + 
mapper.writerWithDefaultPrettyPrinter().writeValue(
+            Path.of("tmp-rules", fileName).toFile(),
+            jsonNode
+        );
+    }
+}
+EOF
+
+# Build classpath
+MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q)
+CLASSPATH="target/classes:${MAVEN_CP}"
+
+# Compile the generator
+javac -cp "$CLASSPATH" JsonGenerator.java
+mkdir -p tmp-rules  # JsonGenerator writes tmp-rules/<name>.json; Jackson's writeValue(File, ...) does not create the parent directory
+# Generate JSON for each rule
+find src/main/java/org/qed/RRuleInstances -name '*.java' | while read file; do
+    className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||')
+    echo "Generating JSON for: $className"
+    java -cp ".:$CLASSPATH" JsonGenerator "$className"
+done
+
+# Cleanup
+rm -f JsonGenerator.java JsonGenerator.class
+
+echo "JSON generation complete. Files are in tmp-rules/"
\ No newline at end of file
diff --git a/scripts/install-dependencies.sh b/scripts/install-dependencies.sh
new file mode 100644
index 0000000..500879f
--- /dev/null
+++ b/scripts/install-dependencies.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Install required dependencies for qed-prover
+
+sudo apt-get update
+sudo apt-get install -y jq z3
+
+# Install cvc5 - try package manager first, otherwise build from source
+sudo apt-get install -y cvc5 || (
+    sudo apt-get install -y cmake libgmp-dev &&
+    git clone --depth 1 https://github.com/cvc5/cvc5.git &&
+    cd cvc5 &&
+    ./configure.sh --auto-download &&
+    cd build &&
+    make -j$(nproc) &&
+    sudo make install
+)
\ No newline at end of file
diff --git a/scripts/test-codegen.sh b/scripts/test-codegen.sh
new file mode 100644
index 0000000..064b171
--- /dev/null
+++ b/scripts/test-codegen.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+
+# Script to generate code for each rule and test whether the rules can be applied correctly
+
+echo "## Code Generation Test Results" >> $GITHUB_STEP_SUMMARY
+echo "" >> $GITHUB_STEP_SUMMARY
+
+# Step 1: Generate code for each rule in RRuleInstances
+# Create temporary Java file for code generation
+cat > RuleGenerator.java << 'EOF'
+import org.qed.Backends.Calcite.CalciteTester;
+import org.qed.*; +import java.nio.file.*; + +public class RuleGenerator { + public static void main(String[] args) throws Exception { + String className = args[0]; + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + + CalciteTester tester = new CalciteTester(); + tester.serialize(rule, CalciteTester.genPath); + + System.out.println("Generated code for: " + rule.name()); + } +} +EOF + +# Build classpath +MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +CLASSPATH="target/classes:${MAVEN_CP}" + +# Compile the generator +javac -cp "$CLASSPATH" RuleGenerator.java + +# Generate code for each rule +find src/main/java/org/qed/RRuleInstances -name '*.java' -not -path '*/RRuleInstances-unprovable/*' | while read file; do + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + java -cp ".:$CLASSPATH" RuleGenerator "$className" +done + +# Step 2: Check for missing tests +for rule_file in src/main/java/org/qed/RRuleInstances/*.java; do + rule_name=$(basename "$rule_file" .java) + if [ ! 
-f "src/main/java/org/qed/Backends/Calcite/Tests/${rule_name}Test.java" ]; then
+        missing_tests="${missing_tests}- ${rule_name}\n"
+        missing_count=$((missing_count + 1))
+    fi
+done
+
+if [ "${missing_count:-0}" -gt 0 ]; then  # missing_count is never initialised, so it is unset when every rule has a test; default it to 0 to keep the test well-formed
+    echo "**⚠️ Warning: Missing tests for $missing_count rules:**" >> $GITHUB_STEP_SUMMARY
+    echo -e "$missing_tests" >> $GITHUB_STEP_SUMMARY
+    echo "" >> $GITHUB_STEP_SUMMARY
+fi
+
+# Step 3: Run all test classes
+# Store results for summary
+total_tests=0
+passed_tests=0
+
+# Find all test files and run them
+total_tests=0
+passed_tests=0
+
+for test_file in src/main/java/org/qed/Backends/Calcite/Tests/*Test.java; do
+    class_name=${test_file#src/main/java/}
+    class_name=${class_name%.java}
+    class_name=${class_name//\//.}
+    test_name=$(basename "$test_file" .java)
+    display_name=${test_name%Test}
+    total_tests=$((total_tests + 1))
+
+    # Run the test and capture output
+    if java -cp "$CLASSPATH" "$class_name" > /tmp/test_output.txt 2>&1; then
+        if grep -q "trivial" /tmp/test_output.txt; then
+            echo "⚠️ ${display_name}: TRIVIAL" >> $GITHUB_STEP_SUMMARY
+        elif grep -q "succeeded" /tmp/test_output.txt && ! 
grep -q "failed" /tmp/test_output.txt; then + echo "✅ ${display_name}: PASSED" >> $GITHUB_STEP_SUMMARY + passed_tests=$((passed_tests + 1)) + else + echo "❌ ${display_name}: FAILED" >> $GITHUB_STEP_SUMMARY + fi + else + echo "❌ ${display_name}: ERROR" >> $GITHUB_STEP_SUMMARY + fi +done + +# Clean up +rm -f RuleGenerator.java RuleGenerator.class /tmp/test_output.txt + +echo "" >> $GITHUB_STEP_SUMMARY +echo "**Summary:** $passed_tests/$total_tests passed" >> $GITHUB_STEP_SUMMARY + +# Exit with error if tests failed +if [ "$passed_tests" -ne "$total_tests" ]; then + exit 1 +fi \ No newline at end of file diff --git a/scripts/test-rules.sh b/scripts/test-rules.sh new file mode 100644 index 0000000..5dc149b --- /dev/null +++ b/scripts/test-rules.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Test all generated rules with qed-prover + +echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY +echo "" >> $GITHUB_STEP_SUMMARY + +failed_rules="" +total_count=0 +passed_count=0 + +for json_file in tmp-rules/*.json; do + rule_name=$(basename "$json_file" .json) + total_count=$((total_count + 1)) + ./qed-prover/target/release/qed-prover "$json_file" || true + + result_file="${json_file%.json}.result" + if [ -f "$result_file" ] && jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then + echo "✅ $rule_name: PASSED" >> $GITHUB_STEP_SUMMARY + passed_count=$((passed_count + 1)) + else + echo "❌ $rule_name: FAILED" >> $GITHUB_STEP_SUMMARY + failed_rules="$failed_rules$rule_name," + fi +done + +echo "" >> $GITHUB_STEP_SUMMARY +echo "**Summary:** $passed_count/$total_count passed" >> $GITHUB_STEP_SUMMARY + +if [ -n "$failed_rules" ]; then + echo "::error::Failed rules: ${failed_rules%,}" + exit 1 +fi \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/CalciteGenerator.java b/src/main/java/org/qed/Backends/Calcite/CalciteGenerator.java new file mode 100644 index 0000000..da75d24 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/CalciteGenerator.java 
@@ -0,0 +1,518 @@ +package org.qed.Backends.Calcite; + +import kala.collection.Seq; +import kala.collection.immutable.ImmutableMap; +import kala.tuple.Tuple; +import kala.tuple.Tuple2; +import org.qed.CodeGenerator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.Backends.Calcite.CalciteGenerator.Env; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.processing.Generated; + +public class CalciteGenerator implements CodeGenerator { + + @Override + public Env preMatch(String rulename) { + return Env.empty(rulename); + } + + @Override + public Env preTransform(Env env) { + var buildEnv = env.declare("call.builder()"); + return buildEnv.getValue().focus(buildEnv.getKey()); + } + + @Override + public Env postTransform(Env env) { + return env.state("call.transformTo(" + env.current() + ".build());"); + } + + @Override + public String translate(String name, Env onMatch, Env transform) { + var builder = new StringBuilder("package org.qed.Backends.Calcite.Generated;\n\n"); + builder.append("import org.apache.calcite.plan.RelOptRuleCall;\n"); + builder.append("import org.apache.calcite.plan.RelRule;\n"); + builder.append("import org.apache.calcite.plan.RelOptUtil;\n"); + builder.append("import org.apache.calcite.rel.RelNode;\n"); + builder.append("import org.apache.calcite.rel.core.JoinRelType;\n"); + builder.append("import org.apache.calcite.rel.logical.*;\n"); + builder.append("import org.qed.Backends.Calcite.EmptyConfig;\n"); + builder.append("\n"); + builder.append("public class " + name + " extends RelRule<" + name + ".Config> {\n"); + builder.append("\tprotected " + name + "(Config config) {\n"); + builder.append("\t\tsuper(config);\n"); + builder.append("\t}\n\n"); + builder.append("\t@Override\n\tpublic void onMatch(RelOptRuleCall call) {\n"); + transform.statements().forEach(statement -> builder.append("\t\t").append(statement).append("\n")); + builder.append("\t}\n\n"); + builder.append("\tpublic interface Config extends 
EmptyConfig {\n"); + builder.append("\t\tConfig DEFAULT = new Config() {};\n\n"); + builder.append("\t\t@Override\n\t\tdefault " + name + " toRule() {\n"); + builder.append("\t\t\treturn new " + name + "(this);\n"); + builder.append("\t\t}\n\n"); + builder.append("\t\t@Override\n\t\tdefault String description() {\n"); + builder.append("\t\t\treturn \"" + name + "\";\n"); + builder.append("\t\t}\n\n"); + builder.append("\t\t@Override\n\t\tdefault RelRule.OperandTransform operandSupplier() {\n"); + builder.append("\t\t\treturn " + onMatch.skeleton() + ";\n"); + builder.append("\t\t}\n\n"); + builder.append("\t}\n"); + builder.append("}\n"); + return builder.toString(); + } + + @Override + public Env onMatchScan(Env env, RelRN.Scan scan) { + return env.symbol(scan.name(), env.current()).grow("operand(RelNode.class).anyInputs()"); + } + + @Override + public Env onMatchFilter(Env env, RelRN.Filter filter) { + var source_match = onMatch(env.next(), filter.source()); + var operator_match = source_match.grow("operand(LogicalFilter.class).oneInput(" + source_match.skeleton() + ")"); + var condition_match = operator_match.focus("((LogicalFilter) " + env.current() + ").getCondition()"); + return onMatch(condition_match, filter.cond()); + } + + @Override + public Env onMatchProject(Env env, RelRN.Project project) { + var source_match = onMatch(env.next(), project.source()); + var operator_match = + source_match.grow("operand(LogicalProject.class).oneInput(" + source_match.skeleton() + ")"); + var map_match = operator_match.focus("((LogicalProject) " + env.current() + ").getProjects()"); + return onMatch(map_match, project.map()); + } + + @Override + public Env onMatchPred(Env env, RexRN.Pred pred) { + return env.symbol(pred.operator().getName(), env.current()); + } + + @Override + public Env onMatchProj(Env env, RexRN.Proj proj) { + return env.symbol(proj.operator().getName(), env.current()); + } + + @Override + public Env onMatchJoinWithSeparateConds(Env env, 
RelRN.JoinWithSeparateConds join) { + var current_join = "((LogicalJoin) " + env.current() + ")"; + var left_source_env = env.next(); + var left_match_env = onMatch(left_source_env, join.left()); + var right_source_env = left_match_env.next(); + var right_match_env = onMatch(right_source_env, join.right()); + var operator_match = + right_match_env.grow("operand(LogicalJoin.class).inputs(" + left_match_env.skeleton() + ", " + right_match_env.skeleton() + ")"); + var cond_source_env = operator_match.focus(current_join + ".getCondition()"); + return onMatch(cond_source_env, join.cond()); + } + + @Override + public Env onMatchJoin(Env env, RelRN.Join join) { + var current_join = "((LogicalJoin) " + env.current() + ")"; + var left_source_env = env.next(); + var left_match_env = onMatch(left_source_env, join.left()); + var right_source_env = left_match_env.next(); + var right_match_env = onMatch(right_source_env, join.right()); + var operator_match = + right_match_env.grow("operand(LogicalJoin.class).inputs(" + left_match_env.skeleton() + ", " + right_match_env.skeleton() + ")"); + var cond_source_env = operator_match.focus(current_join + ".getCondition()"); + return onMatch(cond_source_env, join.cond()); + } + + @Override + public Env onMatchAnd(Env env, RexRN.And and) { + var current_env = env; + String andSymbol = "and_" + env.varId.getAndIncrement(); + current_env = current_env.symbol(andSymbol, current_env.current()); + for (var source : and.sources()) { + current_env = onMatch(current_env, source); + } + return current_env; + } + + @Override + public Env onMatchUnion(Env env, RelRN.Union union) { + boolean all = union.all(); + var current_env = env; + var skeletons = Seq.empty(); + for (var source : union.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < 
skeletons.size(); i++) {
+            if (i > 0) {
+                inputsBuilder.append(", ");
+            }
+            inputsBuilder.append(skeletons.get(i).toString());
+        }
+        String allFilter = all ? ".predicate(union -> union.all)" : ""; // Calcite has no LogicalUnionAll class: UNION ALL is a LogicalUnion whose 'all' flag is set
+        return current_env.grow("operand(LogicalUnion.class)" + allFilter + ".inputs(" + inputsBuilder.toString() + ")");
+    }
+
+    @Override
+    public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { // builds the operand skeleton for an INTERSECT over all of its sources
+        boolean all = intersect.all();
+        var current_env = env;
+        var skeletons = Seq.empty();
+        for (var source : intersect.sources()) {
+            var next_env = current_env.next();
+            var source_env = onMatch(next_env, source);
+            skeletons = skeletons.appended(source_env.skeleton());
+            current_env = source_env; // thread the env through so each source is matched after the previous one
+        }
+        StringBuilder inputsBuilder = new StringBuilder();
+        for (int i = 0; i < skeletons.size(); i++) {
+            if (i > 0) {
+                inputsBuilder.append(", ");
+            }
+            inputsBuilder.append(skeletons.get(i).toString());
+        }
+        String allFilter = all ? ".predicate(intersect -> intersect.all)" : ""; // Calcite has no LogicalIntersectAll class: INTERSECT ALL is a LogicalIntersect whose 'all' flag is set
+        return current_env.grow("operand(LogicalIntersect.class)" + allFilter + ".inputs(" + inputsBuilder.toString() + ")");
+    }
+
+    @Override
+    public Env onMatchMinus(Env env, RelRN.Minus minus) {
+        boolean all = minus.all(); // NOTE(review): read but never used below, so EXCEPT ALL is matched exactly like EXCEPT — confirm this is intended
+        var current_env = env;
+        var skeletons = Seq.empty();
+        for (var source : minus.sources()) {
+            var next_env = current_env.next();
+            var source_env = onMatch(next_env, source);
+            skeletons = skeletons.appended(source_env.skeleton());
+            current_env = source_env;
+        }
+        StringBuilder inputsBuilder = new StringBuilder();
+        for (int i = 0; i < skeletons.size(); i++) {
+            if (i > 0) {
+                inputsBuilder.append(", ");
+            }
+            inputsBuilder.append(skeletons.get(i).toString());
+        }
+        return current_env.grow("operand(LogicalMinus.class).inputs(" + inputsBuilder.toString() + ")");
+    }
+
+    @Override
+    public Env onMatchField(Env env, RexRN.Field field) {
+        String fieldSymbol = "field_" + env.varId.getAndIncrement();
+        return env.symbol(fieldSymbol, env.current());
+    }
+
+    @Override
+    public Env onMatchTrue(Env env, 
RexRN literal) { + String trueSymbol = "true_" + env.varId.getAndIncrement(); + return env.symbol(trueSymbol, env.current()); + } + + @Override + public Env onMatchFalse(Env env, RexRN literal) { + String falseSymbol = "false_" + env.varId.getAndIncrement(); + return env.symbol(falseSymbol, env.current()); + } + + @Override + public Env onMatchEmpty(Env env, RelRN.Empty empty) { + return env.grow("operand(LogicalValues.class).noInputs()"); + } + + @Override + public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceMatch = onMatch(env.next(), aggregate.source()); + return sourceMatch.grow("operand(LogicalAggregate.class).oneInput(" + sourceMatch.skeleton() + ")"); + } + + @Override + public Env onMatchCustom(Env env, RelRN custom) { + if (env.rulename.equals("AggregateProjectConstantToDummyJoin")) { + return switch (custom) { + case org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin.SourceTable st -> env.next().grow("operand(RelNode.class).anyInputs()"); + case org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin.ProjectWithConstantLiterals p -> { var sourceMatch = onMatch(env, p.input()); yield sourceMatch.grow("operand(LogicalProject.class).oneInput(" + sourceMatch.skeleton() + ")");} + case org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin.AggregateGroupingByConstants agg -> { var sourceMatch = onMatch(env, agg.input()); yield sourceMatch.grow("operand(LogicalAggregate.class).oneInput(" + sourceMatch.skeleton() + ")");} + default -> env; + }; + } + if (env.rulename.equals("UnionToDistinct")) { + return switch (custom) { + case org.qed.RRuleInstances.UnionToDistinct.DistinctUnion u -> { var leftMatch = onMatch(env.next(), u.left()); var rightMatch = onMatch(leftMatch.next(), u.right()); yield rightMatch.grow("operand(LogicalUnion.class)" + ".predicate(union -> !union.all)" + ".anyInputs()");} + default -> env; + }; + } + if (env.rulename.equals("UnionPullUpConstants")) { + return switch (custom) { + case 
org.qed.RRuleInstances.UnionPullUpConstants.UnionWithConstantColumns u -> { var leftMatch = onMatch(env.next(), u.left()); var rightMatch = onMatch(leftMatch.next(), u.right()); yield rightMatch.grow("operand(LogicalUnion.class)" + ".predicate(union -> union.getRowType().getFieldCount() > 1)" + ".anyInputs()");} + case org.qed.RRuleInstances.UnionPullUpConstants.LeftProjectionWithConstants left -> { var sourceMatch = onMatch(env, left.input()); yield sourceMatch.grow("operand(LogicalProject.class).oneInput(" + sourceMatch.skeleton() + ")"); } + case org.qed.RRuleInstances.UnionPullUpConstants.RightProjectionWithConstants right -> { var sourceMatch = onMatch(env, right.input()); yield sourceMatch.grow("operand(LogicalProject.class).oneInput(" + sourceMatch.skeleton() + ")");} + default -> env; + }; + } + if (env.rulename.equals("ProjectAggregateMerge")) { + return switch (custom) { + case org.qed.RRuleInstances.ProjectAggregateMerge.ProjectUsingSubsetOfAggregates p -> { var sourceMatch = onMatch(env, p.input()); yield sourceMatch.grow("operand(LogicalProject.class).oneInput(" + sourceMatch.skeleton() + ")");} + case org.qed.RRuleInstances.ProjectAggregateMerge.AggregateWithMultipleCalls a -> { var sourceMatch = onMatch(env, a.input()); yield sourceMatch.grow("operand(LogicalAggregate.class).oneInput(" + sourceMatch.skeleton() + ")");} + case org.qed.RRuleInstances.ProjectAggregateMerge.SourceTable st -> { yield env.next().grow("operand(RelNode.class).anyInputs()");} + default -> env; + }; + } + return CodeGenerator.super.onMatchCustom(env, custom); + } + + @Override + public Env transformScan(Env env, RelRN.Scan scan) { + return env.focus(env.current() + ".push(" + env.symbols().get(scan.name()) + ")"); + } + + @Override + public Env transformFilter(Env env, RelRN.Filter filter) { + var source_transform = transform(env, filter.source()); + var source_expression = source_transform.current(); + var cond_transform = transform(source_transform, filter.cond()); + return 
cond_transform.focus(source_expression + ".filter(" + cond_transform.current() + ")"); + } + + @Override + public Env transformPred(Env env, RexRN.Pred pred) { + if (env.rulename.equals("JoinCommute")) { + var currentEnv = env; var transformedArgs = Seq.empty(); var sources = pred.sources(); var reversedSources = Seq.of(sources.get(1), sources.get(0)); + for (var arg : reversedSources) { + currentEnv = transform(currentEnv, arg); transformedArgs = transformedArgs.appended(currentEnv.current()); currentEnv = currentEnv.focus(env.current()); + } + String argsString = transformedArgs.joinToString(", "); + String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()"; + return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")"); + } + if (env.rulename.equals("ProjectFilterTranspose")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.mapFilterToProjectedColumns(call)");} + if (env.rulename.equals("FilterProjectTranspose")) {return env.focus("RelOptUtil.pushFilterPastProject(((LogicalFilter) call.rel(0)).getCondition(), " + "((LogicalProject) call.rel(1)))");} + if (env.rulename.equals("AggregateFilterTranspose")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.mapFilterToAggregatedColumns(call)");} + if (env.rulename.equals("FilterAggregateTranspose")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.pushFilterPastAggregate(call)");} + return env.focus(env.symbols().get(pred.operator().getName())); + } + + @Override + public Env transformJoinField(Env env, RexRN.JoinField joinField) { + var origJoinDecl = env.declare("(LogicalJoin) call.rel(0)"); + var envWithOrigJoin = origJoinDecl.getValue(); + var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()"); + var envWithCondition = conditionDecl.getValue(); + if (joinField.ordinal() == 0) { + var leftFieldDecl = 
envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()"); + var envWithLeftField = leftFieldDecl.getValue(); + return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")"); + } + else if (joinField.ordinal() == 1) { + var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()"); + var envWithRightField = rightFieldDecl.getValue(); + var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey()); + var envWithAdjustedRightField = adjustedRightFieldDecl.getValue(); + return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")"); + } else { + throw new UnsupportedOperationException("Unsupported join field ordinal: " + joinField.ordinal()); + } + } + + @Override + public Env transformJoinWithPushedConds(Env env, RelRN.JoinWithPushedConds join) { + var builderDecl = env.declare("call.builder()"); + var envWithBuilder = builderDecl.getValue(); + var leftCondDecl = envWithBuilder.declare("org.qed.Backends.Calcite.HelperFunctions.ConditionDecomposer.extractLeftOnlyConditions(" + "((LogicalJoin) call.rel(0)).getCondition(), " + "call.rel(1).getRowType().getFieldCount(), call)"); + var envWithLeftCond = leftCondDecl.getValue(); + var rightCondDecl = envWithLeftCond.declare("org.qed.Backends.Calcite.HelperFunctions.ConditionDecomposer.extractRightOnlyConditions(" + "((LogicalJoin) call.rel(0)).getCondition(), " + "call.rel(1).getRowType().getFieldCount(), " + "call.rel(1).getRowType().getFieldCount() + call.rel(2).getRowType().getFieldCount(), call)"); + var envWithRightCond = rightCondDecl.getValue(); + var joinCondDecl = 
envWithRightCond.declare("org.qed.Backends.Calcite.HelperFunctions.ConditionDecomposer.extractJoinConditions(" + "((LogicalJoin) call.rel(0)).getCondition(), " + "call.rel(1).getRowType().getFieldCount(), " + "call.rel(1).getRowType().getFieldCount() + call.rel(2).getRowType().getFieldCount(), call)"); + var envWithJoinCond = joinCondDecl.getValue(); + return envWithJoinCond.focus(builderDecl.getKey() + ".push(call.rel(1))" + ".filter(" + leftCondDecl.getKey() + ")" + ".push(call.rel(2))" + ".filter(" + rightCondDecl.getKey() + ")" + ".join(JoinRelType.INNER, " + joinCondDecl.getKey() + ")"); + } + + @Override + public Env transformJoin(Env env, RelRN.Join join) { + var left_source_transform = transform(env, join.left()); + var right_source_transform = transform(left_source_transform, join.right()); + var source_expression = right_source_transform.current(); + var cond_transform = transform(right_source_transform, join.cond()); + var join_type = switch (join.ty().semantics()) { + case INNER -> "JoinRelType.INNER"; + case LEFT -> "JoinRelType.LEFT"; + case RIGHT -> "JoinRelType.RIGHT"; + case FULL -> "JoinRelType.FULL"; + case SEMI -> "JoinRelType.SEMI"; + case ANTI -> "JoinRelType.ANTI"; + }; + return cond_transform.focus(source_expression + ".join(" + join_type + ", " + cond_transform.current() + ")"); + } + + @Override + public Env transformAnd(Env env, RexRN.And and) { + var source_transform = env; + var operands = Seq.empty(); + for (var source : and.sources()) { + source_transform = transform(source_transform, source); + operands = operands.appended(source_transform.current()); + source_transform = source_transform.focus(env.current()); + } + return source_transform.focus(env.current() + ".and(" + operands.joinToString(", ") + ")"); + } + + @Override + public Env transformUnion(Env env, RelRN.Union union) { + boolean all = union.all(); + int sourceCount = union.sources().size(); + var current_env = env; + for (var source : union.sources()) { + current_env = 
transform(current_env, source); + } + return current_env.focus(current_env.current() + ".union(" + all + ", " + sourceCount + ")"); + } + + /** Transforms each source in order, then appends .intersect(all, sourceCount). The PruneEmptyIntersect rule instead gets a fixed chain built on the variable named by the second space-separated token of the first recorded statement. */ @Override + public Env transformIntersect(Env env, RelRN.Intersect intersect) { + if (env.rulename.equals("PruneEmptyIntersect")) { + String builderVar = env.statements().get(0).split(" ")[1]; + return env.focus(builderVar + ".push(call.rel(1)).empty()" + ".push(call.rel(2))" + ".intersect(false, 2)"); + } + boolean all = intersect.all(); + int sourceCount = intersect.sources().size(); + var current_env = env; + for (var source : intersect.sources()) { + current_env = transform(current_env, source); + } + /* The all flag is already emitted as the first argument, exactly as for union and minus, so the emitted method name must not vary with it; the builder offers intersect(boolean all, int n) and no intersectAll. */ String methodName = "intersect"; + return current_env.focus(current_env.current() + "." + methodName + "(" + all + ", " + sourceCount + ")"); + } + + /** Transforms each source in order, then appends .minus(all, sourceCount) to the resulting expression. */ @Override + public Env transformMinus(Env env, RelRN.Minus minus) { + boolean all = minus.all(); + int sourceCount = minus.sources().size(); + var current_env = env; + for (var source : minus.sources()) { + current_env = transform(current_env, source); + } + return current_env.focus(current_env.current() + ".minus(" + all + ", " + sourceCount + ")"); + } + + /** Focuses on env.current() followed by .field(...) whose argument text is the string conversion of the Field node itself. NOTE(review): confirm that conversion yields a valid field argument in the emitted code. */ @Override + public Env transformField(Env env, RexRN.Field field) { + return env.focus(env.current() + ".field(" + field + ")"); + } + + /** Focuses on the expression registered in env.symbols() under the projection operator's name; throws RuntimeException when no such symbol exists. */ @Override + public Env transformProj(Env env, RexRN.Proj proj) { + if (!env.symbols().containsKey(proj.operator().getName())) { + throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + ". 
Make sure onMatchProj is properly implemented."); + } + return env.focus(env.symbols().get(proj.operator().getName())); + } + + /** Focuses on source.project(map) built from the transformed source and map; the ProjectMerge rule short-circuits to HelperFunctions.mergeProjections(call). */ @Override + public Env transformProject(Env env, RelRN.Project project) { + if (env.rulename.equals("ProjectMerge")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.mergeProjections(call)");} + var source_transform = transform(env, project.source()); + var source_expression = source_transform.current(); + var map_transform = transform(source_transform, project.map()); + return map_transform.focus(source_expression + ".project(" + map_transform.current() + ")"); + } + + /** Appends .literal(true) to the current expression. */ @Override + public Env transformTrue(Env env, RexRN literal) { + return env.focus(env.current() + ".literal(true)"); + } + + /** Appends .literal(false) to the current expression. */ @Override + public Env transformFalse(Env env, RexRN literal) { + return env.focus(env.current() + ".literal(false)"); + } + + /** Appends .empty() to the current expression. */ @Override + public Env transformEmpty(Env env, RelRN.Empty empty) { + return env.focus(env.current() + ".empty()"); + } + + /** Emits the aggregate expression. Four rule names (AggregateProjectMerge, AggregateExtractProject, AggregateJoinRemove, AggregateJoinJoinRemove) are special-cased; every other rule re-applies the group set and aggregate calls of the matched LogicalAggregate on top of the transformed source. */ @Override + public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + if (env.rulename.equals("AggregateProjectMerge")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.createMergedAggregateProject(call)");} + if (env.rulename.equals("AggregateExtractProject")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.extractProjectForAggregate(call)");} + if (env.rulename.equals("AggregateJoinRemove")) { + var groupSetDecl = env.declare("((LogicalAggregate) call.rel(0)).getGroupSet()"); + var envWithGroupSet = groupSetDecl.getValue(); + var aggCallsDecl = envWithGroupSet.declare("((LogicalAggregate) call.rel(0)).getAggCallList()"); + var envWithAggCalls = aggCallsDecl.getValue(); + String builderVar = env.statements().get(0).split(" ")[1]; + return envWithAggCalls.focus(builderVar + ".push(call.rel(3)).push(call.rel(4))" + ".join(JoinRelType.INNER, " + builderVar + ".literal(true))" + ".aggregate(" + builderVar + ".groupKey(" + groupSetDecl.getKey() + "), " +
aggCallsDecl.getKey() + ")"); + } + if (env.rulename.equals("AggregateJoinJoinRemove")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.aggregateJoinJoinRemove(call)");} + + var sourceTransform = transform(env, aggregate.source()); + String builderWithSource = sourceTransform.current(); + /* FilterAggregateTranspose reads the aggregate from call.rel(1); every other rule reaching this point reads it from call.rel(0). */ String originalAgg; + if (env.rulename.equals("FilterAggregateTranspose")) {originalAgg = "((LogicalAggregate) call.rel(1))";} + else originalAgg = "((LogicalAggregate) call.rel(0))"; + var groupSetDecl = sourceTransform.declare(originalAgg + ".getGroupSet()"); + var envWithGroupSet = groupSetDecl.getValue(); + /* builderWithSource is spliced into this declaration and again into the final expression, so the emitted code contains the source chain twice. */ var groupKeyDecl = envWithGroupSet.declare(builderWithSource + ".groupKey(" + groupSetDecl.getKey() + ")"); + var envWithGroupKey = groupKeyDecl.getValue(); + var aggCallsDecl = envWithGroupKey.declare(originalAgg + ".getAggCallList()"); + var envWithAggCalls = aggCallsDecl.getValue(); + return envWithAggCalls.focus(builderWithSource + ".aggregate(" + groupKeyDecl.getKey() + ", " + aggCallsDecl.getKey() + ")"); + } + + /** Delegates four rule names to HelperFunctions. For JoinCommute's ProjectionRelRN it projects the fields of the transformed source reordered as [rightCount, rightCount + leftCount) followed by [0, rightCount), with the counts read from call.rel(1) and call.rel(2). Any other node goes to unimplementedTransform. */ @Override + public Env transformCustom(Env env, RelRN custom) { + if (env.rulename.equals("AggregateProjectConstantToDummyJoin")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.aggregateProjectConstantToDummyJoin(call)");} + if (env.rulename.equals("UnionToDistinct")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.unionToDistinct(call)");} + if (env.rulename.equals("UnionPullUpConstants")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.unionPullUpConstants(call)");} + if (env.rulename.equals("ProjectAggregateMerge")) {return env.focus("org.qed.Backends.Calcite.HelperFunctions.projectAggregateMerge(call)");} + return switch (custom) { + case org.qed.RRuleInstances.JoinCommute.ProjectionRelRN projection -> { + var sourceEnv = transform(env, projection.source()); + var leftTableDecl = sourceEnv.declare("call.rel(1)"); + var envWithLeftTable = leftTableDecl.getValue(); + var rightTableDecl =
envWithLeftTable.declare("call.rel(2)"); + var envWithRightTable = rightTableDecl.getValue(); + var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithRightCount = rightColCountDecl.getValue(); + var projectionIndicesDecl = envWithRightCount.declare("java.util.stream.IntStream.concat(" + "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " + "java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" + ").boxed().collect(java.util.stream.Collectors.toList())"); + var envWithProjectionIndices = projectionIndicesDecl.getValue(); + var fieldRefsDecl = envWithProjectionIndices.declare(sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")"); + var envWithFieldRefs = fieldRefsDecl.getValue(); + yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")"); + } + default -> unimplementedTransform(env, custom); + }; + } + + /** Generator state threaded through the transform methods. Every method below returns a new Env with one component changed; all copies share the same varId counter instance. */ public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, + ImmutableMap symbols, String rulename) { + /** Initial state: fresh counter, rel 0, focus on call.rel(0), no statements and no symbols. */ public static Env empty(String rulename) { + return new Env(new AtomicInteger(), 0, "call.rel(0)", "/* Unspecified skeleton */", Seq.empty(), + ImmutableMap.empty(), rulename); + } + /** Increments rel and refocuses on call.rel(rel + 1). */ public Env next() { + return new Env(varId, rel + 1, "call.rel(" + (rel + 1) + ")", skeleton, statements, symbols, rulename); + } + /** Replaces only the current expression. */ public Env focus(String target) { + return new Env(varId, rel, target, skeleton, statements, symbols, rulename); + } + /** Appends one statement to the recorded statement list. */ public Env state(String statement) { + return new Env(varId, rel, current, skeleton, statements.appended(statement), symbols, rulename); + } + /** Registers expression under symbol in the symbol map. */ public Env symbol(String symbol, String expression) { + return
new Env(varId, rel, current, skeleton, statements, symbols.putted(symbol, expression), rulename); + } + /** Allocates the name var_N from the shared counter, records the statement that assigns expression to it, and returns the pair (name, new Env). */ public Tuple2 declare(String expression) { + var name = "var_" + varId.getAndIncrement(); + return Tuple.of(name, state("var " + name + " = " + expression + ";")); + } + /** Allocates the lambda parameter name s_N from the shared counter and replaces the skeleton with the lambda text that applies requirement to that parameter. */ public Env grow(String requirement) { + var vn = "s_" + varId.getAndIncrement(); + return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols, rulename); + } + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/CalciteTester.java b/src/main/java/org/qed/Backends/Calcite/CalciteTester.java new file mode 100644 index 0000000..1a07b5b --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/CalciteTester.java @@ -0,0 +1,180 @@ +package org.qed.Backends.Calcite; + +import kala.tuple.Tuple; +import kala.collection.Seq; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.qed.*; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.jdbc.CalcitePrepare.SparkHandler.RuleSetBuilder; +import org.apache.calcite.rel.rules.*; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalValues; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgramBuilder; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Set; +import java.util.stream.Collectors; + +/** Driver that loads rules into HepPlanner instances, writes generated rule sources into genPath, dumps rule JSON into rulePath, and runs the classes found under the Tests source directory. */ public class CalciteTester { + public static String genPath = + ProjectPaths.baseDir().resolve("src/main/java/org/qed/Backends/Calcite/Generated").toString(); + public static String rulePath = ProjectPaths.baseDir().resolve("rules").toString(); + + public static HepPlanner
loadRules(java.util.List rules) { + System.out.printf("Loading Rules: %s\n", + rules.stream() + .map(rule -> rule.getClass().getSimpleName()) + .collect(java.util.stream.Collectors.joining(", "))); + + var builder = new HepProgramBuilder(); + for (var rule : rules) { + builder.addRuleInstance(rule); + } + return new HepPlanner(builder.build()); + } + + /** Varargs convenience overload that forwards to loadRules(List). */ public static HepPlanner loadRules(RelOptRule... rules) { + return loadRules(java.util.Arrays.asList(rules)); + } + + /** Prints the rule's simple class name and returns a HepPlanner whose program contains only that rule. */ public static HepPlanner loadRule(RelOptRule rule) { + System.out.printf("Loading Rule: %s\n", rule.getClass().getSimpleName()); + var builder = new HepProgramBuilder().addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + + /** Same as loadRule(rule), but adds the given match limit to the program before the rule instance. */ public static HepPlanner loadRule(RelOptRule rule, int matchLimit) { + System.out.printf("Loading Rule: %s (match limit: %d)\n", rule.getClass().getSimpleName(), matchLimit); + var builder = new HepProgramBuilder() + .addMatchLimit(matchLimit) + .addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + + /** Instantiates, through the public no-argument constructor, one org.qed.RRuleInstances class per .java file name found in the RRuleInstances source directory. Yields an empty Seq when listFiles returns null; any load failure is rethrown as RuntimeException naming the class. */ public static Seq ruleList() { + java.io.File ruleDir = + ProjectPaths.baseDir().resolve("src/main/java/org/qed/RRuleInstances").toFile(); + java.io.File[] files = ruleDir.listFiles((dir, name) -> name.endsWith(".java")); + + java.util.List rules = new java.util.ArrayList<>(); + + if (files != null) { + for (java.io.File file : files) { + String className = file.getName().replace(".java", ""); + + try { + Class clazz = Class.forName("org.qed.RRuleInstances." 
+ className); + RRule rule = (RRule) clazz.getConstructor().newInstance(); + rules.add(rule); + } catch (Exception e) { + throw new RuntimeException("Failed to load rule: " + className, e); + } + } + } + + return Seq.from(rules); + + // var families = Seq.from(reflections.getSubTypesOf(RRule.RRuleFamily.class)) + // .filter(clazz -> !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers())) + // .mapUnchecked(clazz -> { + // Constructor constructor = clazz.getDeclaredConstructor(); + // constructor.setAccessible(true); + // return constructor.newInstance(); + // }) + // .map(r -> (RRule.RRuleFamily) r); + + // return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); + } + + /** Calls rule.dump for every rule from ruleList(), targeting rulePath/NAME.json where NAME is rule.name(). */ public static void verify() { + ruleList().forEachUnchecked(rule -> rule.dump(rulePath + "/" + rule.name() + ".json")); + } + + /** Serializes every rule from ruleList() into genPath using a fresh CalciteTester. */ public static void generate() { + var tester = new CalciteTester(); + ruleList().forEach(r -> tester.serialize(r, genPath)); + } + + /** Reflectively invokes the static runTest() of org.qed.Backends.Calcite.Tests.NAME for every file NAME.java ending in Test.java under the Tests source directory; any failure is rethrown as RuntimeException naming the class. */ public static void runAllTests() { + String packagePath = + ProjectPaths.baseDir().resolve("src/main/java/org/qed/Backends/Calcite/Tests").toString(); + java.io.File testDir = new java.io.File(packagePath); + java.io.File[] testFiles = testDir.listFiles((dir, name) -> name.endsWith("Test.java")); + if (testFiles != null) { + for (java.io.File testFile : testFiles) { + String className = "org.qed.Backends.Calcite.Tests." 
+ testFile.getName().replace(".java", ""); + try { + Class testClass = Class.forName(className); + testClass.getMethod("runTest").invoke(null); + } catch (Exception e) { + throw new RuntimeException("Failed to run test: " + className, e); + } + } + } + } + + /** Regenerates all rule sources, then runs all tests. */ public static void main(String[] args) throws IOException { + // var rule = new org.qed.RRuleInstances.AggregateExtractProject(); + // System.out.println(rule.explain()); + // Files.createDirectories(Path.of(rulePath)); + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rules = new RRuleInstance.JoinAssociate(); + // Files.createDirectories(Path.of(rulePath)); + // for (var rule : rules.family()) { + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // } + generate(); + runAllTests(); + } + + /** Generates the source for rule with a new CalciteGenerator and writes it, encoded as UTF-8, to NAME.java under path, where NAME is rule.name(). Writing stays best-effort: an IOException is reported on stderr together with the target file and is not rethrown. */ public void serialize(RRule rule, String path) { + var generator = new CalciteGenerator(); + var code_gen = generator.generate(rule); + try { + Files.write(Path.of(path, rule.name() + ".java"), code_gen.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + } catch (IOException ioe) { + System.err.println("Failed to write " + Path.of(path, rule.name() + ".java") + ": " + ioe); + } + } + + /** Loads each path in tests through JSONDeserializer, skips (with a stderr message) any entry that does not hold exactly two nodes, and checks that the rule rewrites node 0 into node 1. */ public void test(RelOptRule rule, Seq tests) { + System.out.println("Testing rule " + rule.getClass().getSimpleName()); + var runner = loadRule(rule); + var exams = tests.mapUnchecked(t -> Tuple.of(t, JSONDeserializer.load(new File(t)))); + for (var entry : exams) { + if (entry.getValue().size() != 2) { + System.err.println(entry.getKey() + " does not have exactly two nodes, and thus is not a valid test"); + continue; + } + verify(runner, entry.getValue().get(0), entry.getValue().get(1)); + } + } + + /** Runs the planner on source and prints succeeded or failed depending on whether explain() of the result equals explain() of target; on failure the three plans are printed. */ public void verify(HepPlanner runner, RelNode source, RelNode target) { + runner.setRoot(source); + var answer = runner.findBestExp(); + + String answerExplain = answer.explain(); + String targetExplain =
target.explain(); + + if(answerExplain.equals(targetExplain)) { + System.out.println("succeeded"); + return; + } + System.out.println("failed"); + System.out.println("> Given source RelNode:\n" + source.explain()); + System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/EmptyConfig.java b/src/main/java/org/qed/Backends/Calcite/EmptyConfig.java new file mode 100644 index 0000000..d981492 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/EmptyConfig.java @@ -0,0 +1,33 @@ +package org.qed.Backends.Calcite; + +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.tools.RelBuilderFactory; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** RelRule.Config base with default methods only: each with* method ignores its argument and returns this, description() returns a fixed placeholder string, and operandSupplier() describes a single RelNode operand with any inputs. */ public interface EmptyConfig extends RelRule.Config { + @Override + default RelRule.Config withRelBuilderFactory(RelBuilderFactory factory) { + return this; + } + + @Override + default @Nullable String description() { + return "Unspecified Config Description"; + } + + @Override + default RelRule.Config withDescription(@Nullable String description) { + return this; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s -> s.operand(RelNode.class).anyInputs(); + } + + @Override + default RelRule.Config withOperandSupplier(RelRule.OperandTransform transform) { + return this; + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/AggregateExtractProject.java b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateExtractProject.java new file mode 100644 index 0000000..0919a20 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateExtractProject.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import
org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches a LogicalAggregate over any RelNode and transforms to HelperFunctions.extractProjectForAggregate(call).build(). NOTE(review): var_2 is declared in onMatch but not used. */ public class AggregateExtractProject extends RelRule { + protected AggregateExtractProject(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(org.qed.Backends.Calcite.HelperFunctions.extractProjectForAggregate(call).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default AggregateExtractProject toRule() { + return new AggregateExtractProject(this); + } + + @Override + default String description() { + return "AggregateExtractProject"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalAggregate.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/AggregateFilterTranspose.java b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateFilterTranspose.java new file mode 100644 index 0000000..129f702 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateFilterTranspose.java @@ -0,0 +1,44 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalAggregate over LogicalFilter over any RelNode, rebuilds the aggregate (group set and aggregate calls of call.rel(0)) on call.rel(2), and filters the result with HelperFunctions.mapFilterToAggregatedColumns(call). */ public class AggregateFilterTranspose extends RelRule { + protected AggregateFilterTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + var var_4 = ((LogicalAggregate) call.rel(0)).getGroupSet(); + var var_5
= var_3.push(call.rel(2)).groupKey(var_4); + var var_6 = ((LogicalAggregate) call.rel(0)).getAggCallList(); + call.transformTo(var_3.push(call.rel(2)).aggregate(var_5, var_6).filter(org.qed.Backends.Calcite.HelperFunctions.mapFilterToAggregatedColumns(call)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default AggregateFilterTranspose toRule() { + return new AggregateFilterTranspose(this); + } + + @Override + default String description() { + return "AggregateFilterTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalAggregate.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinJoinRemove.java b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinJoinRemove.java new file mode 100644 index 0000000..c07b1e7 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinJoinRemove.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches the LogicalAggregate-over-LogicalJoin tree described by operandSupplier and transforms to HelperFunctions.aggregateJoinJoinRemove(call).build(). NOTE(review): var_15 is declared in onMatch but not used. */ public class AggregateJoinJoinRemove extends RelRule { + protected AggregateJoinJoinRemove(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_15 = call.builder(); + call.transformTo(org.qed.Backends.Calcite.HelperFunctions.aggregateJoinJoinRemove(call).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default AggregateJoinJoinRemove toRule() { + return new AggregateJoinJoinRemove(this);
+ } + + @Override + default String description() { + return "AggregateJoinJoinRemove"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_14 -> s_14.operand(LogicalAggregate.class).oneInput(s_13 -> s_13.operand(LogicalJoin.class).inputs(s_8 -> s_8.operand(LogicalJoin.class).inputs(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_6 -> s_6.operand(LogicalJoin.class).inputs(s_4 -> s_4.operand(RelNode.class).anyInputs(), s_5 -> s_5.operand(RelNode.class).anyInputs())), s_11 -> s_11.operand(LogicalJoin.class).inputs(s_9 -> s_9.operand(RelNode.class).anyInputs(), s_10 -> s_10.operand(RelNode.class).anyInputs()))); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinRemove.java b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinRemove.java new file mode 100644 index 0000000..575e1bd --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateJoinRemove.java @@ -0,0 +1,43 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalAggregate over a LogicalJoin of two LogicalJoins, INNER-joins call.rel(3) and call.rel(4) on literal true, and re-applies the group set and aggregate calls of call.rel(0). */ public class AggregateJoinRemove extends RelRule { + protected AggregateJoinRemove(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_10 = call.builder(); + var var_11 = ((LogicalAggregate) call.rel(0)).getGroupSet(); + var var_12 = ((LogicalAggregate) call.rel(0)).getAggCallList(); + call.transformTo(var_10.push(call.rel(3)).push(call.rel(4)).join(JoinRelType.INNER, var_10.literal(true)).aggregate(var_10.groupKey(var_11), var_12).build()); + } + + public interface Config extends EmptyConfig { + Config
DEFAULT = new Config() {}; + + @Override + default AggregateJoinRemove toRule() { + return new AggregateJoinRemove(this); + } + + @Override + default String description() { + return "AggregateJoinRemove"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_9 -> s_9.operand(LogicalAggregate.class).oneInput(s_8 -> s_8.operand(LogicalJoin.class).inputs(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_6 -> s_6.operand(LogicalJoin.class).inputs(s_4 -> s_4.operand(RelNode.class).anyInputs(), s_5 -> s_5.operand(RelNode.class).anyInputs()))); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectConstantToDummyJoin.java b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectConstantToDummyJoin.java new file mode 100644 index 0000000..9fb218d --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectConstantToDummyJoin.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalAggregate over LogicalProject over any RelNode and transforms to HelperFunctions.aggregateProjectConstantToDummyJoin(call).build(). NOTE(review): var_3 is declared in onMatch but not used. */ public class AggregateProjectConstantToDummyJoin extends RelRule { + protected AggregateProjectConstantToDummyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(org.qed.Backends.Calcite.HelperFunctions.aggregateProjectConstantToDummyJoin(call).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default AggregateProjectConstantToDummyJoin toRule() { + return new AggregateProjectConstantToDummyJoin(this); + } + + @Override +
default String description() { + return "AggregateProjectConstantToDummyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalAggregate.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectMerge.java new file mode 100644 index 0000000..f43130b --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/AggregateProjectMerge.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalAggregate over LogicalProject over any RelNode and transforms to HelperFunctions.createMergedAggregateProject(call).build(). NOTE(review): var_3 is declared in onMatch but not used. */ public class AggregateProjectMerge extends RelRule { + protected AggregateProjectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(org.qed.Backends.Calcite.HelperFunctions.createMergedAggregateProject(call).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default AggregateProjectMerge toRule() { + return new AggregateProjectMerge(this); + } + + @Override + default String description() { + return "AggregateProjectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalAggregate.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/FilterAggregateTranspose.java
b/src/main/java/org/qed/Backends/Calcite/Generated/FilterAggregateTranspose.java new file mode 100644 index 0000000..ce109d7 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/FilterAggregateTranspose.java @@ -0,0 +1,44 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalFilter over LogicalAggregate over any RelNode, filters call.rel(2) with HelperFunctions.pushFilterPastAggregate(call), and re-applies the group set and aggregate calls of call.rel(1). NOTE(review): the push/filter chain on var_3 appears twice in onMatch, once for var_5 and once for the result. */ public class FilterAggregateTranspose extends RelRule { + protected FilterAggregateTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + var var_4 = ((LogicalAggregate) call.rel(1)).getGroupSet(); + var var_5 = var_3.push(call.rel(2)).filter(org.qed.Backends.Calcite.HelperFunctions.pushFilterPastAggregate(call)).groupKey(var_4); + var var_6 = ((LogicalAggregate) call.rel(1)).getAggCallList(); + call.transformTo(var_3.push(call.rel(2)).filter(org.qed.Backends.Calcite.HelperFunctions.pushFilterPastAggregate(call)).aggregate(var_5, var_6).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterAggregateTranspose toRule() { + return new FilterAggregateTranspose(this); + } + + @Override + default String description() { + return "FilterAggregateTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalAggregate.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/FilterIntoJoin.java b/src/main/java/org/qed/Backends/Calcite/Generated/FilterIntoJoin.java new file mode 100644 index 0000000..2d5a2c0 --- /dev/null +++
b/src/main/java/org/qed/Backends/Calcite/Generated/FilterIntoJoin.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalFilter over LogicalJoin and INNER-joins call.rel(2) and call.rel(3) on the AND of the join condition of call.rel(1) and the filter condition of call.rel(0). */ public class FilterIntoJoin extends RelRule { + protected FilterIntoJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterIntoJoin toRule() { + return new FilterIntoJoin(this); + } + + @Override + default String description() { + return "FilterIntoJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/FilterMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/FilterMerge.java new file mode 100644 index 0000000..3a61f9a --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/FilterMerge.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import
org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalFilter over LogicalFilter over any RelNode and filters call.rel(2) with the AND of the conditions of call.rel(1) and call.rel(0). */ public class FilterMerge extends RelRule { + protected FilterMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(var_3.push(call.rel(2)).and(((LogicalFilter) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterMerge toRule() { + return new FilterMerge(this); + } + + @Override + default String description() { + return "FilterMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/FilterProjectTranspose.java b/src/main/java/org/qed/Backends/Calcite/Generated/FilterProjectTranspose.java new file mode 100644 index 0000000..a3faea3 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/FilterProjectTranspose.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Rule in the Generated package: matches LogicalFilter over LogicalProject over any RelNode, filters call.rel(2) with RelOptUtil.pushFilterPastProject applied to the filter condition and the project, then re-applies the project's expressions. */ public class FilterProjectTranspose extends RelRule { + protected FilterProjectTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(RelOptUtil.pushFilterPastProject(((LogicalFilter) call.rel(0)).getCondition(), ((LogicalProject)
call.rel(1)))).project(((LogicalProject) call.rel(1)).getProjects()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterProjectTranspose toRule() { + return new FilterProjectTranspose(this); + } + + @Override + default String description() { + return "FilterProjectTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceFalse.java b/src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceFalse.java new file mode 100644 index 0000000..39fae06 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceFalse.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +public class FilterReduceFalse extends RelRule { + protected FilterReduceFalse(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterReduceFalse toRule() { + return new FilterReduceFalse(this); + } + + @Override + default String description() { + return "FilterReduceFalse"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git 
a/src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceTrue.java b/src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceTrue.java new file mode 100644 index 0000000..bb34e6c --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/FilterReduceTrue.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches any LogicalFilter over any input and replaces it with the filter input, call.rel(1). NOTE(review): the filter condition is not inspected here; confirm the always-true check happens elsewhere. */ public class FilterReduceTrue extends RelRule { + protected FilterReduceTrue(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterReduceTrue toRule() { + return new FilterReduceTrue(this); + } + + @Override + default String description() { + return "FilterReduceTrue"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/FilterSetOpTranspose.java b/src/main/java/org/qed/Backends/Calcite/Generated/FilterSetOpTranspose.java new file mode 100644 index 0000000..bd2d3b2 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/FilterSetOpTranspose.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalFilter over a two-input LogicalUnion and replaces it with union(false, 2) of the two inputs, each filtered by the original filter condition. NOTE(review): the all flag of the matched union is never read; confirm a hard-coded false is intended. */ public class FilterSetOpTranspose extends RelRule { + protected FilterSetOpTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).push(call.rel(3)).filter(((LogicalFilter) call.rel(0)).getCondition()).union(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterSetOpTranspose toRule() { + return new FilterSetOpTranspose(this); + } + + @Override + default String description() { + return "FilterSetOpTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/IntersectMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/IntersectMerge.java new file mode 100644 index 0000000..217447f --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/IntersectMerge.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalIntersect whose first input is itself a two-input LogicalIntersect and flattens the three leaf inputs into a single intersect(false, 3). */ public class IntersectMerge extends RelRule { + protected IntersectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).intersect(false, 3).build()); + } + + public
interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default IntersectMerge toRule() { + return new IntersectMerge(this); + } + + @Override + default String description() { + return "IntersectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalIntersect.class).inputs(s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinAddRedundantSemiJoin.java new file mode 100644 index 0000000..7e99200 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinAddRedundantSemiJoin.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a two-input LogicalJoin and rebuilds it as an INNER join between a SEMI join of (left, right) and the right input again, reusing the original join condition for both joins. */ public class JoinAddRedundantSemiJoin extends RelRule { + protected JoinAddRedundantSemiJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(0)).getCondition()).push(call.rel(2)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinAddRedundantSemiJoin toRule() { + return new JoinAddRedundantSemiJoin(this); + } + + @Override + default String description() { + return "JoinAddRedundantSemiJoin"; + }
+ + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinCommute.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinCommute.java new file mode 100644 index 0000000..ca57631 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinCommute.java @@ -0,0 +1,55 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a two-input LogicalJoin and emits an INNER join with the inputs swapped, followed by a projection that restores the original column order (left columns first, then right columns). The join condition is cast to RexCall and its operands 0 and 1 are cast to RexInputRef, so any other condition shape fails with ClassCastException; operand 0 is treated as a left-input column and operand 1 as a right-input column (its index is reduced by the left field count). The swapped join is built twice on the same builder: once to obtain the field list var_17 and once for the emitted plan. */ public class JoinCommute extends RelRule { + protected JoinCommute(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + var var_4 = (LogicalJoin) call.rel(0); + var var_5 = (org.apache.calcite.rex.RexCall) var_4.getCondition(); + var var_6 = ((org.apache.calcite.rex.RexInputRef) var_5.getOperands().get(0)).getIndex(); + var var_7 = (LogicalJoin) call.rel(0); + var var_8 = (org.apache.calcite.rex.RexCall) var_7.getCondition(); + var var_9 = ((org.apache.calcite.rex.RexInputRef) var_8.getOperands().get(1)).getIndex(); + var var_10 = call.rel(1).getRowType().getFieldCount(); + var var_11 = var_9 - var_10; + var var_12 = call.rel(1); + var var_13 = call.rel(2); + var var_14 = var_12.getRowType().getFieldCount(); + var var_15 = var_13.getRowType().getFieldCount(); + var var_16 = java.util.stream.IntStream.concat(java.util.stream.IntStream.range(var_15, var_15 + var_14), java.util.stream.IntStream.range(0, var_15)).boxed().collect(java.util.stream.Collectors.toList()); + var var_17 =
var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).fields(var_16); + call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).project(var_17).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinCommute toRule() { + return new JoinCommute(this); + } + + @Override + default String description() { + return "JoinCommute"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinConditionPush.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinConditionPush.java new file mode 100644 index 0000000..a048315 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinConditionPush.java @@ -0,0 +1,45 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a two-input LogicalJoin and splits its condition with HelperFunctions.ConditionDecomposer: var_6 filters the left input, var_7 filters the right input, and var_8 becomes the condition of the rebuilt INNER join. NOTE(review): var_4 is an unused second builder, and the exact split semantics live in ConditionDecomposer, which is not visible here. */ public class JoinConditionPush extends RelRule { + protected JoinConditionPush(Config config) { + super(config); + } + +
@Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + var var_5 = call.builder(); + var var_6 = org.qed.Backends.Calcite.HelperFunctions.ConditionDecomposer.extractLeftOnlyConditions(((LogicalJoin) call.rel(0)).getCondition(), call.rel(1).getRowType().getFieldCount(), call); + var var_7 = org.qed.Backends.Calcite.HelperFunctions.ConditionDecomposer.extractRightOnlyConditions(((LogicalJoin) call.rel(0)).getCondition(), call.rel(1).getRowType().getFieldCount(), call.rel(1).getRowType().getFieldCount() + call.rel(2).getRowType().getFieldCount(), call); + var var_8 = org.qed.Backends.Calcite.HelperFunctions.ConditionDecomposer.extractJoinConditions(((LogicalJoin) call.rel(0)).getCondition(), call.rel(1).getRowType().getFieldCount(), call.rel(1).getRowType().getFieldCount() + call.rel(2).getRowType().getFieldCount(), call); + call.transformTo(var_5.push(call.rel(1)).filter(var_6).push(call.rel(2)).filter(var_7).join(JoinRelType.INNER, var_8).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinConditionPush toRule() { + return new JoinConditionPush(this); + } + + @Override + default String description() { + return "JoinConditionPush"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinExtractFilter.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinExtractFilter.java new file mode 100644 index 0000000..a6b16ce --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinExtractFilter.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import
org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a two-input LogicalJoin and replaces it with an INNER join on literal true followed by a filter that carries the original join condition. */ public class JoinExtractFilter extends RelRule { + protected JoinExtractFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.INNER, var_3.push(call.rel(1)).push(call.rel(2)).literal(true)).filter(((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinExtractFilter toRule() { + return new JoinExtractFilter(this); + } + + @Override + default String description() { + return "JoinExtractFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinPushTransitivePredicates.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinPushTransitivePredicates.java new file mode 100644 index 0000000..8bc9c85 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinPushTransitivePredicates.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalFilter over a two-input LogicalJoin and replaces the pair with an INNER join of the join inputs whose condition is AND(join condition, filter condition). NOTE(review): no predicate inference is performed in this class; confirm this plain merge is the intended behavior for a rule with this name. */ public class JoinPushTransitivePredicates extends RelRule { + protected JoinPushTransitivePredicates(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 =
call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinPushTransitivePredicates toRule() { + return new JoinPushTransitivePredicates(this); + } + + @Override + default String description() { + return "JoinPushTransitivePredicates"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceFalse.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceFalse.java new file mode 100644 index 0000000..be0b531 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceFalse.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a two-input LogicalJoin and rebuilds it as an INNER join of the same inputs with literal false as the condition. NOTE(review): the original join condition is not inspected here; confirm the always-false check happens elsewhere. */ public class JoinReduceFalse extends RelRule { + protected JoinReduceFalse(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.INNER, var_5.push(call.rel(1)).push(call.rel(2)).literal(false)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinReduceFalse toRule() { + return
new JoinReduceFalse(this); + } + + @Override + default String description() { + return "JoinReduceFalse"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceTrue.java b/src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceTrue.java new file mode 100644 index 0000000..3d8fd86 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/JoinReduceTrue.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a two-input LogicalJoin and rebuilds it as an INNER join of the same inputs with the original join condition. */ public class JoinReduceTrue extends RelRule { + protected JoinReduceTrue(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinReduceTrue toRule() { + return new JoinReduceTrue(this); + } + + @Override + default String description() { + return "JoinReduceTrue"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/MinusMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/MinusMerge.java new file mode 100644
index 0000000..f18ab7d --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/MinusMerge.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalMinus whose first input is a two-input LogicalMinus, i.e. (a minus b) minus c, and rebuilds it as a minus (b union c) by pushing a, b, c and applying union(false, 2) followed by minus(false, 2). */ public class MinusMerge extends RelRule { + protected MinusMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).union(false, 2).minus(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default MinusMerge toRule() { + return new MinusMerge(this); + } + + @Override + default String description() { + return "MinusMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalMinus.class).inputs(s_2 -> s_2.operand(LogicalMinus.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/ProjectAggregateMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/ProjectAggregateMerge.java new file mode 100644 index 0000000..00e4508 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/ProjectAggregateMerge.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import
org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalProject over a LogicalAggregate and replaces it with the plan built by HelperFunctions.projectAggregateMerge(call). NOTE(review): the local builder var_3 is never used. */ public class ProjectAggregateMerge extends RelRule { + protected ProjectAggregateMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(org.qed.Backends.Calcite.HelperFunctions.projectAggregateMerge(call).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectAggregateMerge toRule() { + return new ProjectAggregateMerge(this); + } + + @Override + default String description() { + return "ProjectAggregateMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalAggregate.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/ProjectFilterTranspose.java b/src/main/java/org/qed/Backends/Calcite/Generated/ProjectFilterTranspose.java new file mode 100644 index 0000000..ac15097 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/ProjectFilterTranspose.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalProject over a LogicalFilter, applies the project expressions directly to the filter input, and then filters the projected rows with the condition returned by HelperFunctions.mapFilterToProjectedColumns(call). */ public class ProjectFilterTranspose extends RelRule { + protected ProjectFilterTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject)
call.rel(0)).getProjects()).filter(org.qed.Backends.Calcite.HelperFunctions.mapFilterToProjectedColumns(call)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectFilterTranspose toRule() { + return new ProjectFilterTranspose(this); + } + + @Override + default String description() { + return "ProjectFilterTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/ProjectMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/ProjectMerge.java new file mode 100644 index 0000000..db585d2 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/ProjectMerge.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalProject directly over another LogicalProject and replaces both with the plan built by HelperFunctions.mergeProjections(call). NOTE(review): the local builder var_3 is never used. */ public class ProjectMerge extends RelRule { + protected ProjectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(org.qed.Backends.Calcite.HelperFunctions.mergeProjections(call).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectMerge toRule() { + return new ProjectMerge(this); + } + + @Override + default String description() { + return "ProjectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 ->
s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyFilter.java b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyFilter.java new file mode 100644 index 0000000..8adfccc --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyFilter.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +/** Generated planner rule. Matches a LogicalFilter over a LogicalValues leaf and replaces it with whatever call.builder().empty() builds. NOTE(review): the operand accepts any LogicalValues without checking that it has no rows, and nothing is pushed onto the builder before empty(); confirm both are handled elsewhere. */ public class PruneEmptyFilter extends RelRule { + protected PruneEmptyFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyFilter toRule() { + return new PruneEmptyFilter(this); + } + + @Override + default String description() { + return "PruneEmptyFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyIntersect.java b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyIntersect.java new file mode 100644 index 0000000..e6a77e0 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyIntersect.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import
org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +public class PruneEmptyIntersect extends RelRule { + protected PruneEmptyIntersect(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).empty().push(call.rel(2)).intersect(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyIntersect toRule() { + return new PruneEmptyIntersect(this); + } + + @Override + default String description() { + return "PruneEmptyIntersect"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyMinus.java b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyMinus.java new file mode 100644 index 0000000..78d0ed9 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyMinus.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; +import org.qed.Backends.Calcite.EmptyConfig; + +public class PruneEmptyMinus extends RelRule { + protected PruneEmptyMinus(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + 
@Override
+        default PruneEmptyMinus toRule() {
+            return new PruneEmptyMinus(this);
+        }
+
+        @Override
+        default String description() {
+            return "PruneEmptyMinus";
+        }
+        // Pattern: a LogicalMinus whose first input is a LogicalValues (no predicate on its contents) and whose second input is any RelNode.
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_2 -> s_2.operand(LogicalMinus.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs());
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyProject.java b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyProject.java
new file mode 100644
index 0000000..98fa167
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyProject.java
@@ -0,0 +1,41 @@
+package org.qed.Backends.Calcite.Generated;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig;
+/** Planner rule for a LogicalProject directly over a LogicalValues; the match is replaced by whatever RelBuilder.empty() builds. */
+public class PruneEmptyProject extends RelRule {
+    protected PruneEmptyProject(Config config) {
+        super(config);
+    }
+    // NOTE(review): empty() is invoked on a builder onto which nothing has been pushed - confirm the builder returned by call.builder() supports that.
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        var var_2 = call.builder();
+        call.transformTo(var_2.empty().build());
+    }
+    /** Rule configuration: the DEFAULT instance, the rule factory, the description string and the operand pattern. */
+    public interface Config extends EmptyConfig {
+        Config DEFAULT = new Config() {};
+
+        @Override
+        default PruneEmptyProject toRule() {
+            return new PruneEmptyProject(this);
+        }
+
+        @Override
+        default String description() {
+            return "PruneEmptyProject";
+        }
+        // NOTE(review): the LogicalValues operand carries no predicate, so a project over non-empty Values matches too - confirm this is intended.
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs());
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyUnion.java b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyUnion.java
new file mode 100644
index 0000000..d827fcd
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/Generated/PruneEmptyUnion.java
@@ -0,0 +1,41 @@
+package org.qed.Backends.Calcite.Generated;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig;
+/** Planner rule for a LogicalUnion whose two inputs are both LogicalValues; the match is replaced by whatever RelBuilder.empty() builds. */
+public class PruneEmptyUnion extends RelRule {
+    protected PruneEmptyUnion(Config config) {
+        super(config);
+    }
+    // NOTE(review): empty() is invoked on a builder onto which nothing has been pushed - confirm the builder returned by call.builder() supports that.
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        var var_3 = call.builder();
+        call.transformTo(var_3.empty().build());
+    }
+    /** Rule configuration: the DEFAULT instance, the rule factory, the description string and the operand pattern. */
+    public interface Config extends EmptyConfig {
+        Config DEFAULT = new Config() {};
+
+        @Override
+        default PruneEmptyUnion toRule() {
+            return new PruneEmptyUnion(this);
+        }
+
+        @Override
+        default String description() {
+            return "PruneEmptyUnion";
+        }
+        // NOTE(review): neither LogicalValues operand carries a predicate, so non-empty Values inputs match too - confirm this is intended.
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs());
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/SemiJoinFilterTranspose.java b/src/main/java/org/qed/Backends/Calcite/Generated/SemiJoinFilterTranspose.java
new file mode 100644
index 0000000..9b3f8b2
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/Generated/SemiJoinFilterTranspose.java
@@ -0,0 +1,41 @@
+package org.qed.Backends.Calcite.Generated;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig;
+/** Planner rule for a LogicalFilter over a LogicalJoin: the filter is moved onto the join's first input and the join is rebuilt with JoinRelType.SEMI. */
+public class SemiJoinFilterTranspose extends RelRule {
+    protected SemiJoinFilterTranspose(Config config) {
+
super(config);
+    }
+    // rel(0) is the filter, rel(1) the join, rel(2)/rel(3) the join inputs; both conditions are reused unchanged.
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        var var_4 = call.builder();
+        call.transformTo(var_4.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).push(call.rel(3)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(1)).getCondition()).build());
+    }
+    /** Rule configuration: the DEFAULT instance, the rule factory, the description string and the operand pattern. */
+    public interface Config extends EmptyConfig {
+        Config DEFAULT = new Config() {};
+
+        @Override
+        default SemiJoinFilterTranspose toRule() {
+            return new SemiJoinFilterTranspose(this);
+        }
+
+        @Override
+        default String description() {
+            return "SemiJoinFilterTranspose";
+        }
+        // NOTE(review): the LogicalJoin operand has no join-type predicate although onMatch always emits a SEMI join - confirm non-semi joins cannot reach this rule.
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()));
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/UnionMerge.java b/src/main/java/org/qed/Backends/Calcite/Generated/UnionMerge.java
new file mode 100644
index 0000000..4effcbb
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/Generated/UnionMerge.java
@@ -0,0 +1,41 @@
+package org.qed.Backends.Calcite.Generated;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig;
+/** Planner rule that flattens a LogicalUnion whose first input is itself a two-input LogicalUnion into a single three-input union. */
+public class UnionMerge extends RelRule {
+    protected UnionMerge(Config config) {
+        super(config);
+    }
+    // NOTE(review): the result is always built with union(false, 3); the 'all' flags of the two matched unions are not inspected - confirm.
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        var var_5 = call.builder();
+        call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).union(false, 3).build());
+    }
+    /** Rule configuration: the DEFAULT instance, the rule factory, the description string and the operand pattern. */
+    public interface Config extends EmptyConfig {
+        Config DEFAULT = new Config() {};
+
+        @Override
+        default UnionMerge toRule() {
+            return new UnionMerge(this);
+        }
+
+        @Override
+
default String description() {
+            return "UnionMerge";
+        }
+        // Pattern: union(union(s_0, s_1), s_3); the operands are numbered in that nesting order, which is why onMatch pushes rel(2), rel(3) and rel(4).
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_4 -> s_4.operand(LogicalUnion.class).inputs(s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs());
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/UnionPullUpConstants.java b/src/main/java/org/qed/Backends/Calcite/Generated/UnionPullUpConstants.java
new file mode 100644
index 0000000..78c8741
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/Generated/UnionPullUpConstants.java
@@ -0,0 +1,41 @@
+package org.qed.Backends.Calcite.Generated;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig;
+/** Planner rule for a LogicalUnion with more than one output field; the rewrite itself is delegated to HelperFunctions.unionPullUpConstants. */
+public class UnionPullUpConstants extends RelRule {
+    protected UnionPullUpConstants(Config config) {
+        super(config);
+    }
+    // var_3 is not used afterwards; the helper obtains its own builder from 'call'.
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        var var_3 = call.builder();
+        call.transformTo(org.qed.Backends.Calcite.HelperFunctions.unionPullUpConstants(call).build());
+    }
+    /** Rule configuration: the DEFAULT instance, the rule factory, the description string and the operand pattern. */
+    public interface Config extends EmptyConfig {
+        Config DEFAULT = new Config() {};
+
+        @Override
+        default UnionPullUpConstants toRule() {
+            return new UnionPullUpConstants(this);
+        }
+
+        @Override
+        default String description() {
+            return "UnionPullUpConstants";
+        }
+        // Pattern: any LogicalUnion whose row type has at least two fields, with arbitrary inputs.
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_2 -> s_2.operand(LogicalUnion.class).predicate(union -> union.getRowType().getFieldCount() > 1).anyInputs();
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/Generated/UnionToDistinct.java
b/src/main/java/org/qed/Backends/Calcite/Generated/UnionToDistinct.java
new file mode 100644
index 0000000..e2de365
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/Generated/UnionToDistinct.java
@@ -0,0 +1,41 @@
+package org.qed.Backends.Calcite.Generated;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.*;
+import org.qed.Backends.Calcite.EmptyConfig;
+/** Planner rule for a LogicalUnion whose 'all' flag is false; the rewrite itself is delegated to HelperFunctions.unionToDistinct. */
+public class UnionToDistinct extends RelRule {
+    protected UnionToDistinct(Config config) {
+        super(config);
+    }
+    // var_3 is not used afterwards; the helper obtains its own builder from 'call'.
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        var var_3 = call.builder();
+        call.transformTo(org.qed.Backends.Calcite.HelperFunctions.unionToDistinct(call).build());
+    }
+    /** Rule configuration: the DEFAULT instance, the rule factory, the description string and the operand pattern. */
+    public interface Config extends EmptyConfig {
+        Config DEFAULT = new Config() {};
+
+        @Override
+        default UnionToDistinct toRule() {
+            return new UnionToDistinct(this);
+        }
+
+        @Override
+        default String description() {
+            return "UnionToDistinct";
+        }
+        // Pattern: any LogicalUnion with all == false, with arbitrary inputs.
+        @Override
+        default RelRule.OperandTransform operandSupplier() {
+            return s_2 -> s_2.operand(LogicalUnion.class).predicate(union -> !union.all).anyInputs();
+        }
+
+    }
+}
diff --git a/src/main/java/org/qed/Backends/Calcite/HelperFunctions.java b/src/main/java/org/qed/Backends/Calcite/HelperFunctions.java
new file mode 100644
index 0000000..92956bf
--- /dev/null
+++ b/src/main/java/org/qed/Backends/Calcite/HelperFunctions.java
@@ -0,0 +1,649 @@
+package org.qed.Backends.Calcite;
+
+import java.util.Set;
+import java.util.Map;
+import java.util.List;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.ArrayList;
+import java.util.Collections;
+
+import org.qed.RuleBuilder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import
org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.logical.LogicalUnion;
+import org.apache.calcite.rel.logical.LogicalFilter;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexInputRef;
+/** Helper routines that build the rewritten expressions and plans for rules whose transformation is written by hand; each takes the matched operands from a RelOptRuleCall. */
+public class HelperFunctions {
+    public List compose(RelNode base, List inner, List outer) { // projects 'inner' over 'base', then rewrites 'outer' onto base's fields via RelOptUtil.pushPastProject
+        var builder = RuleBuilder.create(); // NOTE(review): instance method using a fresh RuleBuilder, unlike the static helpers below that use call.builder()
+        return RelOptUtil.pushPastProject(outer, (Project) builder.push(base).project(inner).build());
+    }
+    /** Rewrites the condition of the filter at rel(1): an input reference that the project at rel(0) forwards as a plain column is re-pointed at that project's output position; other references are kept as they are. */
+    public static org.apache.calcite.rex.RexNode mapFilterToProjectedColumns(RelOptRuleCall call) {
+        var filter = (LogicalFilter) call.rel(1);
+        var project = (LogicalProject) call.rel(0);
+        var rexBuilder = project.getCluster().getRexBuilder();
+        var tableToProjectMapping = new HashMap(); // input column index -> position of the project expression that forwards it
+        for (int projectedPos = 0; projectedPos < project.getProjects().size(); projectedPos++) {
+            var projectExpr = project.getProjects().get(projectedPos);
+            if (projectExpr instanceof RexInputRef inputRef) {
+                tableToProjectMapping.put(inputRef.getIndex(), projectedPos); // a column projected twice keeps the last position seen
+            }
+        }
+        return filter.getCondition().accept(new RexShuttle() {
+            @Override
+            public org.apache.calcite.rex.RexNode visitInputRef(RexInputRef inputRef) {
+                Integer projectedPos = tableToProjectMapping.get(inputRef.getIndex());
+                if (projectedPos != null) {
+                    return rexBuilder.makeInputRef(inputRef.getType(), projectedPos);
+                }
+                return inputRef; // column is not forwarded by the project: reference left untouched
+            }
+        });
+    }
+    /** Rewrites the condition of the filter at rel(1) from input-column indices to the group-key output positions of the aggregate at rel(0); throws IllegalStateException for a reference that is not a group column. */
+    public static org.apache.calcite.rex.RexNode mapFilterToAggregatedColumns(RelOptRuleCall call) {
+        var filter = (LogicalFilter)
call.rel(1);
+        var aggregate = (LogicalAggregate) call.rel(0);
+        var rexBuilder = aggregate.getCluster().getRexBuilder();
+        var inputToAggregateMapping = new HashMap(); // group column index -> its position among the group keys, in group-set iteration order
+        int outputPos = 0;
+        for (int groupCol : aggregate.getGroupSet()) {
+            inputToAggregateMapping.put(groupCol, outputPos++);
+        }
+        return filter.getCondition().accept(new RexShuttle() {
+            @Override
+            public org.apache.calcite.rex.RexNode visitInputRef(RexInputRef inputRef) {
+                Integer aggregatedPos = inputToAggregateMapping.get(inputRef.getIndex());
+                if (aggregatedPos != null) {
+                    return rexBuilder.makeInputRef(inputRef.getType(), aggregatedPos);
+                }
+                throw new IllegalStateException(
+                    "Filter references non-group column at index " + inputRef.getIndex() +
+                    " which cannot be pushed past aggregate");
+            }
+        });
+    }
+    /** Inverse direction: rewrites the condition of the filter at rel(0) from group-key output positions of the aggregate at rel(1) back to input-column indices; throws IllegalStateException for any other reference. */
+    public static org.apache.calcite.rex.RexNode pushFilterPastAggregate(RelOptRuleCall call) {
+        var filter = (LogicalFilter) call.rel(0);
+        var aggregate = (LogicalAggregate) call.rel(1);
+        var rexBuilder = aggregate.getCluster().getRexBuilder();
+        var aggregateToInputMapping = new HashMap(); // position among the group keys -> group column index in the aggregate's input
+        int outputPos = 0;
+        for (int inputCol : aggregate.getGroupSet()) {
+            aggregateToInputMapping.put(outputPos++, inputCol);
+        }
+        return filter.getCondition().accept(new RexShuttle() {
+            @Override
+            public org.apache.calcite.rex.RexNode visitInputRef(RexInputRef inputRef) {
+                Integer originalPos = aggregateToInputMapping.get(inputRef.getIndex());
+                if (originalPos != null) {
+                    return rexBuilder.makeInputRef(inputRef.getType(), originalPos);
+                }
+                throw new IllegalStateException(
+                    "Filter references non-group column at index " + inputRef.getIndex() +
+                    " which cannot be pushed past aggregate");
+            }
+        });
+    }
+    /** Returns true when every project expression at an index reported by RelOptUtil.getAllFields for the aggregate at rel(0) is a plain RexInputRef; the project is rel(1). */
+    public static boolean canMergeAggregateProject(RelOptRuleCall call) {
+        var aggregate = (LogicalAggregate) call.rel(0);
+        var project = (LogicalProject) call.rel(1);
+        var interestingFields = org.apache.calcite.plan.RelOptUtil.getAllFields(aggregate);
+        for (int fieldIndex :
interestingFields) {
+            var projectExpr = project.getProjects().get(fieldIndex);
+            if (!(projectExpr instanceof RexInputRef)) {
+                return false; // a computed expression feeds the aggregate, so the project cannot be bypassed
+            }
+        }
+        return true;
+    }
+    /** Builds the aggregate at rel(0) directly on the input of the project at rel(1), remapping group keys and aggregate calls through the project's input references, then re-projects to the original column order. */
+    public static org.apache.calcite.tools.RelBuilder createMergedAggregateProject(RelOptRuleCall call) {
+        var aggregate = (LogicalAggregate) call.rel(0);
+        var project = (LogicalProject) call.rel(1);
+        var builder = call.builder();
+        var interestingFields = org.apache.calcite.plan.RelOptUtil.getAllFields(aggregate);
+        var fieldMapping = new HashMap(); // project output index -> index of the input column it forwards
+        for (int fieldIndex : interestingFields) {
+            var projectExpr = project.getProjects().get(fieldIndex);
+            if (projectExpr instanceof RexInputRef inputRef) {
+                fieldMapping.put(fieldIndex, inputRef.getIndex());
+            }
+        }
+        builder.push(project.getInput());
+        var newGroupSet = aggregate.getGroupSet().permute(fieldMapping);
+        var groupKey = builder.groupKey(newGroupSet);
+        var mappedAggCalls = new java.util.ArrayList();
+        var sourceCount = aggregate.getInput().getRowType().getFieldCount();
+        var targetCount = project.getInput().getRowType().getFieldCount();
+        var targetMapping = org.apache.calcite.util.mapping.Mappings.target(
+            fieldMapping,
+            sourceCount,
+            targetCount
+        );
+        // Re-target every aggregate call from project-output indices to project-input indices.
+        for (var aggCall : aggregate.getAggCallList()) {
+            mappedAggCalls.add(aggCall.transform(targetMapping));
+        }
+        builder.aggregate(groupKey, mappedAggCalls);
+        // The permuted group set may list the keys in a different order; compute the projection that restores the original order.
+        var originalGroupList = aggregate.getGroupSet().asList();
+        var newGroupList = newGroupSet.asList();
+        var reorderingIndices = new java.util.ArrayList();
+        for (int originalFieldIndex : originalGroupList) {
+            int mappedFieldIndex = fieldMapping.get(originalFieldIndex); // presumably never null because canMergeAggregateProject was checked first - TODO confirm
+            int positionInNewAggregate = newGroupList.indexOf(mappedFieldIndex);
+            reorderingIndices.add(positionInNewAggregate);
+        }
+        for (int i = aggregate.getGroupCount(); i < aggregate.getGroupCount() + aggregate.getAggCallList().size(); i++) {
+            reorderingIndices.add(i); // aggregate-call columns keep their positions after the group keys
+        }
+        builder.project(builder.fields(reorderingIndices));
+
+        return builder;
+    }
+
+
public static org.apache.calcite.tools.RelBuilder mergeProjections(RelOptRuleCall call) { // collapses project rel(0) over project rel(1) into one project over rel(2)
+        var outerProject = (LogicalProject) call.rel(0);
+        var innerProject = (LogicalProject) call.rel(1);
+        var source = call.rel(2);
+        var builder = call.builder();
+        builder.push(source);
+        var composedExpressions = new java.util.ArrayList();
+        var rexBuilder = builder.getRexBuilder(); // NOTE(review): not used anywhere in this method
+        // Substitute each input reference of an outer expression by the inner expression it points at.
+        for (var outerExpr : outerProject.getProjects()) {
+            var composedExpr = outerExpr.accept(new org.apache.calcite.rex.RexShuttle() {
+                @Override
+                public org.apache.calcite.rex.RexNode visitInputRef(org.apache.calcite.rex.RexInputRef inputRef) {
+                    int fieldIndex = inputRef.getIndex();
+                    if (fieldIndex < innerProject.getProjects().size()) {
+                        return innerProject.getProjects().get(fieldIndex);
+                    }
+                    return inputRef; // index beyond the inner project's outputs: reference left untouched
+                }
+            });
+            composedExpressions.add(composedExpr);
+        }
+        builder.project(composedExpressions);
+
+        return builder;
+    }
+    /** Splits a condition over the concatenated fields of two inputs into its top-level AND conjuncts and groups them by which side's fields they reference. */
+    public static class ConditionDecomposer {
+        public static RexNode extractLeftOnlyConditions(RexNode condition, int leftFieldCount, RelOptRuleCall call) { // conjuncts using only fields 0..leftFieldCount-1; null when there are none
+            List leftConditions = new ArrayList<>();
+            extractConditionsForSide(condition, leftConditions, 0, leftFieldCount - 1);
+            if (leftConditions.isEmpty()) return null;
+            if (leftConditions.size() == 1) return leftConditions.get(0);
+            return RexUtil.composeConjunction(call.builder().getRexBuilder(), leftConditions);
+        }
+        /** Returns the conjunction of the conjuncts that use only fields leftFieldCount..totalFieldCount-1, with every field index shifted down by leftFieldCount; null when there are none. */
+        public static RexNode extractRightOnlyConditions(RexNode condition, int leftFieldCount, int totalFieldCount, RelOptRuleCall call) {
+            List rightConditions = new ArrayList<>();
+            extractConditionsForSide(condition, rightConditions, leftFieldCount, totalFieldCount - 1);
+            if (rightConditions.isEmpty()) return null;
+            org.apache.calcite.rex.RexBuilder rexBuilder = call.builder().getRexBuilder();
+            List adjustedConditions = new ArrayList<>();
+            for (RexNode cond : rightConditions) {
+                adjustedConditions.add(adjustFieldIndices(cond, -leftFieldCount, rexBuilder));
+            }
+            if (adjustedConditions.size() ==
1) return adjustedConditions.get(0);
+            return RexUtil.composeConjunction(rexBuilder, adjustedConditions);
+        }
+        /** Returns the conjunction of the conjuncts that reference at least one field from each side (indices are not shifted); null when there are none. */
+        public static RexNode extractJoinConditions(RexNode condition, int leftFieldCount, int totalFieldCount, RelOptRuleCall call) {
+            List joinConditions = new ArrayList<>();
+            extractCrossTableConditions(condition, joinConditions, leftFieldCount, totalFieldCount);
+            if (joinConditions.isEmpty()) return null;
+            if (joinConditions.size() == 1) return joinConditions.get(0);
+            return RexUtil.composeConjunction(call.builder().getRexBuilder(), joinConditions);
+        }
+        /** Walks nested AND calls and appends to 'result' every non-AND conjunct whose field references all lie in [minField, maxField]. */
+        private static void extractConditionsForSide(RexNode condition, List result, int minField, int maxField) {
+            if (condition instanceof RexCall call && call.getOperator().getKind() == org.apache.calcite.sql.SqlKind.AND) {
+                for (RexNode operand : call.getOperands()) {
+                    extractConditionsForSide(operand, result, minField, maxField);
+                }
+            } else if (referencesOnlyFields(condition, minField, maxField)) {
+                result.add(condition);
+            }
+        }
+        /** Walks nested AND calls and appends to 'result' every non-AND conjunct that references fields on both sides of leftFieldCount. */
+        private static void extractCrossTableConditions(RexNode condition, List result, int leftFieldCount, int totalFieldCount) {
+            if (condition instanceof RexCall call && call.getOperator().getKind() == org.apache.calcite.sql.SqlKind.AND) {
+                for (RexNode operand : call.getOperands()) {
+                    extractCrossTableConditions(operand, result, leftFieldCount, totalFieldCount);
+                }
+            } else if (referencesBothSides(condition, leftFieldCount, totalFieldCount)) {
+                result.add(condition);
+            }
+        }
+        /** True when the condition references at least one field and all referenced fields lie in [minField, maxField]; a conjunct with no field reference is therefore never selected. */
+        private static boolean referencesOnlyFields(RexNode condition, int minField, int maxField) {
+            Set fields = new HashSet<>();
+            collectFieldReferences(condition, fields);
+            return !fields.isEmpty() && fields.stream().allMatch(f -> f >= minField && f <= maxField);
+        }
+        /** True when the condition references a field below leftFieldCount and also one in [leftFieldCount, totalFieldCount). */
+        private static boolean referencesBothSides(RexNode condition, int leftFieldCount, int totalFieldCount) {
+            Set fields = new HashSet<>();
+            collectFieldReferences(condition, fields);
+            boolean hasLeft =
fields.stream().anyMatch(f -> f < leftFieldCount); + boolean hasRight = fields.stream().anyMatch(f -> f >= leftFieldCount && f < totalFieldCount); + return hasLeft && hasRight; + } + + private static void collectFieldReferences(RexNode node, Set fields) { + if (node instanceof RexInputRef inputRef) { + fields.add(inputRef.getIndex()); + } else if (node instanceof RexCall call) { + for (RexNode operand : call.getOperands()) { + collectFieldReferences(operand, fields); + } + } + } + + private static RexNode adjustFieldIndices(RexNode node, int offset, org.apache.calcite.rex.RexBuilder rexBuilder) { + if (node instanceof RexInputRef inputRef) { + return rexBuilder.makeInputRef(inputRef.getType(), inputRef.getIndex() + offset); + } else if (node instanceof RexCall call) { + List newOperands = new ArrayList<>(); + for (RexNode operand : call.getOperands()) { + newOperands.add(adjustFieldIndices(operand, offset, rexBuilder)); + } + return rexBuilder.makeCall(call.getOperator(), newOperands); + } + return node; + } + } + + public static org.apache.calcite.tools.RelBuilder extractProjectForAggregate(RelOptRuleCall call) { + var builder = call.builder(); + LogicalAggregate aggregate = (LogicalAggregate) call.rel(0); + RelNode input = call.rel(1); + Set usedFields = new HashSet<>(); + for (int field : aggregate.getGroupSet()) { + usedFields.add(field); + } + for (AggregateCall aggCall : aggregate.getAggCallList()) { + for (int field : aggCall.getArgList()) { + usedFields.add(field); + } + if (aggCall.filterArg >= 0) { + usedFields.add(aggCall.filterArg); + } + } + List sortedFields = new ArrayList<>(usedFields); + Collections.sort(sortedFields); + Map fieldMapping = new HashMap<>(); + for (int i = 0; i < sortedFields.size(); i++) { + fieldMapping.put(sortedFields.get(i), i); + } + builder.push(input); + List projectedFields = new ArrayList<>(); + for (int field : sortedFields) { + projectedFields.add(builder.field(field)); + } + builder.project(projectedFields); + 
ImmutableBitSet.Builder newGroupSet = ImmutableBitSet.builder();
+        for (int field : aggregate.getGroupSet()) {
+            newGroupSet.set(fieldMapping.get(field)); // group keys re-indexed onto the narrowed project
+        }
+
+        List newAggCalls = new ArrayList<>();
+        for (AggregateCall aggCall : aggregate.getAggCallList()) {
+            List newArgList = new ArrayList<>();
+            for (int field : aggCall.getArgList()) {
+                newArgList.add(fieldMapping.get(field));
+            }
+            int newFilterArg = aggCall.filterArg >= 0 ? fieldMapping.get(aggCall.filterArg) : -1;
+            // Rebind the call to the project just built (builder.peek()); the group count is passed unchanged for both the old and the new value.
+            newAggCalls.add(aggCall.adaptTo(
+                builder.peek(),
+                newArgList,
+                newFilterArg,
+                aggregate.getGroupCount(),
+                aggregate.getGroupCount()
+            ));
+        }
+        // NOTE(review): only getGroupSet() is carried over here; any additional grouping sets on the aggregate are not - confirm.
+        builder.aggregate(
+            builder.groupKey(newGroupSet.build()),
+            newAggCalls
+        );
+
+        return builder;
+    }
+    /** Rebuilds the aggregate at rel(0) over a plan that no longer contains the middle table (tblB): group-key indices pointing past tblB are shifted down by tblB's field count. Relies on fixed operand positions (rel(3), rel(4), rel(5), rel(6), rel(10), rel(11)). */
+    public static org.apache.calcite.tools.RelBuilder aggregateJoinJoinRemove(RelOptRuleCall call) {
+        var builder = call.builder();
+        var originalAgg = (LogicalAggregate) call.rel(0);
+        var groupSet = originalAgg.getGroupSet();
+        var aggCalls = originalAgg.getAggCallList();
+
+        // Compute the field counts of tblA and tblB
+        int tblAFieldCount = call.rel(3).getRowType().getFieldCount();
+        int tblBFieldCount = call.rel(6).getRowType().getFieldCount();
+
+        // Remap the field indices: drop tblB's fields
+        // Before: tblA[0..tblAFieldCount-1], tblB[tblAFieldCount..tblAFieldCount+tblBFieldCount-1], tblC[tblAFieldCount+tblBFieldCount..]
+        // After: tblA[0..tblAFieldCount-1], tblC[tblAFieldCount..]
+        ImmutableBitSet.Builder newGroupSetBuilder = ImmutableBitSet.builder();
+        for (int field : groupSet) {
+            if (field < tblAFieldCount) {
+                // From tblA: index unchanged
+                newGroupSetBuilder.set(field);
+            } else if (field >= tblAFieldCount + tblBFieldCount) {
+                // From tblC: subtract tblB's field count
+                newGroupSetBuilder.set(field - tblBFieldCount);
+            }
+            // If the field falls inside tblB's range, skip it (should not happen, because the rule requires that the aggregate does not use tblB)
+        }
+        // NOTE(review): aggCalls is passed on as-is below; unlike the group set, its argument indices are not shifted - confirm the calls only reference tblA columns.
+        return builder
+            .push(call.rel(4)) // scanA1
+            .push(call.rel(5)) // scanA2
+            .join(JoinRelType.INNER, builder.literal(true))
+            .push(call.rel(10)) // scanC1
+            .push(call.rel(11)) // scanC2
+            .join(JoinRelType.INNER, builder.literal(true))
+            .join(JoinRelType.LEFT, builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0))) // join condition is fixed: first field of each side
+            .aggregate(builder.groupKey(newGroupSetBuilder.build()), aggCalls);
+    }
+    /** Rewrites aggregate rel(0) over project rel(1): the literal expressions of the project are moved into a one-row Values relation that is inner-joined (condition TRUE) to the project's input, and the project then references those columns instead. */
+    public static org.apache.calcite.tools.RelBuilder aggregateProjectConstantToDummyJoin(RelOptRuleCall call) {
+        final LogicalAggregate aggregate = call.rel(0);
+        final LogicalProject project = call.rel(1);
+
+        RelBuilder builder = call.builder();
+        org.apache.calcite.rex.RexBuilder rexBuilder = builder.getRexBuilder();
+
+        builder.push(project.getInput());
+        int offset = project.getInput().getRowType().getFieldCount(); // index of the first Values column once the join is built
+
+        org.apache.calcite.rel.type.RelDataTypeFactory.Builder valuesType =
+            rexBuilder.getTypeFactory().builder();
+        java.util.List literals = new java.util.ArrayList<>();
+        java.util.List projects = project.getProjects();
+        // Collect every literal of the project, in project order, together with a column type for it.
+        int colIndex = 0;
+        for (int i = 0; i < projects.size(); i++) {
+            org.apache.calcite.rex.RexNode node = projects.get(i);
+            if (node instanceof org.apache.calcite.rex.RexLiteral) {
+                literals.add((org.apache.calcite.rex.RexLiteral) node);
+                // Use uniform names "col0", "col1", ...
+                valuesType.add("col" + colIndex++, node.getType());
+            }
+        }
+        // One-row Values holding the literals, joined to the project's input with an always-true condition.
+        builder.values(com.google.common.collect.ImmutableList.of(literals), valuesType.build());
+        builder.join(org.apache.calcite.rel.core.JoinRelType.INNER, rexBuilder.makeLiteral(true));
+        // Rebuild the project list: the n-th literal becomes a reference to join field offset + n, every other expression is kept.
+        java.util.List newProjects = new java.util.ArrayList<>();
+        int literalCounter = 0;
+        for (org.apache.calcite.rex.RexNode exp : project.getProjects()) {
+            if (exp instanceof org.apache.calcite.rex.RexLiteral) {
+                newProjects.add(builder.field(offset + literalCounter++));
+            } else {
+                newProjects.add(exp);
+            }
+        }
+
+        builder.project(newProjects);
+        builder.aggregate(
+            builder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()),
+            aggregate.getAggCallList());
+
+        return builder;
+    }
+    /** Rebuilds the union at rel(0) as a union with all = true over the same inputs, followed by RelBuilder.distinct(). */
+    public static org.apache.calcite.tools.RelBuilder unionToDistinct(RelOptRuleCall call) {
+        final LogicalUnion union = call.rel(0);
+        final RelBuilder relBuilder = call.builder();
+        relBuilder.pushAll(union.getInputs());
+        relBuilder.union(true, union.getInputs().size());
+        relBuilder.distinct();
+        return relBuilder;
+    }
+    /** Replaces the columns of the union at rel(0) that the pulled-up predicates report as constant by those constants in a project above a narrower union; returns a builder holding the original union when no such column exists. */
+    public static org.apache.calcite.tools.RelBuilder unionPullUpConstants(RelOptRuleCall call) {
+        final LogicalUnion union = call.rel(0);
+        final org.apache.calcite.rex.RexBuilder rexBuilder = union.getCluster().getRexBuilder();
+        final org.apache.calcite.rel.metadata.RelMetadataQuery mq = call.getMetadataQuery();
+        final org.apache.calcite.plan.RelOptPredicateList predicates = mq.getPulledUpPredicates(union);
+
+        if (org.apache.calcite.plan.RelOptPredicateList.isEmpty(predicates)) {
+            return call.builder().push(union);
+        }
+        // Keep only the constants that are attached to a plain input reference, keyed by column index.
+        final java.util.Map constants = new java.util.HashMap<>();
+        for (java.util.Map.Entry e : predicates.constantMap.entrySet()) {
+            if (e.getKey() instanceof org.apache.calcite.rex.RexInputRef) {
+                constants.put(((org.apache.calcite.rex.RexInputRef) e.getKey()).getIndex(), e.getValue());
+            }
+        }
+
+        if (constants.isEmpty()) {
+            return call.builder().push(union);
+        }
+
+        java.util.List
fields = union.getRowType().getFieldList();
+        java.util.List topChildExprs = new java.util.ArrayList<>(); // expressions of the project placed on top of the new union
+        java.util.List topChildExprsFields = new java.util.ArrayList<>(); // field names matching topChildExprs
+        java.util.List refs = new java.util.ArrayList<>(); // references to the columns that are not constant
+        org.apache.calcite.util.ImmutableBitSet.Builder refsIndexBuilder = org.apache.calcite.util.ImmutableBitSet.builder();
+        // Constant columns become the constant (cast to the field type when the types differ); all other columns become input references.
+        for (org.apache.calcite.rel.type.RelDataTypeField field : fields) {
+            final org.apache.calcite.rex.RexNode constant = constants.get(field.getIndex());
+            if (constant != null) {
+                if (constant.getType().equals(field.getType())) {
+                    topChildExprs.add(constant);
+                } else {
+                    topChildExprs.add(rexBuilder.makeCast(field.getType(), constant, true, false));
+                }
+                topChildExprsFields.add(field.getName());
+            } else {
+                final org.apache.calcite.rex.RexNode expr = rexBuilder.makeInputRef(union, field.getIndex());
+                topChildExprs.add(expr);
+                topChildExprsFields.add(field.getName());
+                refs.add(expr);
+                refsIndexBuilder.set(field.getIndex());
+            }
+        }
+        org.apache.calcite.util.ImmutableBitSet refsIndex = refsIndexBuilder.build();
+        // Re-index the top expressions so they address the narrowed union, which only carries the non-constant columns.
+        final org.apache.calcite.util.mapping.Mappings.TargetMapping mapping =
+            org.apache.calcite.plan.RelOptUtil.permutation(refs, union.getInput(0).getRowType()).inverse();
+        topChildExprs = org.apache.calcite.rex.RexUtil.apply(mapping, topChildExprs);
+        // Put a project that keeps only the non-constant columns on top of every union input.
+        final RelBuilder relBuilder = call.builder();
+        for (org.apache.calcite.rel.RelNode input : union.getInputs()) {
+            java.util.List> newChildExprs = // NOTE(review): 'List>' - the type arguments look lost in this copy of the patch; restore them from the original file
+                new java.util.ArrayList<>();
+            for (int j : refsIndex) {
+                newChildExprs.add(
+                    org.apache.calcite.util.Pair.of(
+                        rexBuilder.makeInputRef(input, j),
+                        input.getRowType().getFieldList().get(j).getName()));
+            }
+            if (newChildExprs.isEmpty()) {
+                // At least a single item in project is required.
+ newChildExprs.add( + org.apache.calcite.util.Pair.of(topChildExprs.get(0), topChildExprsFields.get(0))); + } + // Add the input with project on top + relBuilder.push(input); + relBuilder.project( + org.apache.calcite.util.Pair.left(newChildExprs), + org.apache.calcite.util.Pair.right(newChildExprs)); + } + relBuilder.union(union.all, union.getInputs().size()); + // Create top Project fixing nullability of fields + relBuilder.project(topChildExprs, topChildExprsFields); + relBuilder.convert(union.getRowType(), false); + + return relBuilder; + } + + public static org.apache.calcite.tools.RelBuilder projectAggregateMerge(RelOptRuleCall call) { + final LogicalProject project = call.rel(0); + final LogicalAggregate aggregate = call.rel(1); + final org.apache.calcite.plan.RelOptCluster cluster = aggregate.getCluster(); + + // Do a quick check. If all aggregate calls are used, and there are no CASE + // expressions, there is nothing to do. + final org.apache.calcite.util.ImmutableBitSet bits = + org.apache.calcite.plan.RelOptUtil.InputFinder.bits(project.getProjects(), null); + if (bits.contains( + org.apache.calcite.util.ImmutableBitSet.range(aggregate.getGroupCount(), + aggregate.getRowType().getFieldCount())) + && kindCount(project.getProjects(), org.apache.calcite.sql.SqlKind.CASE) == 0) { + return null; + } + + // Replace 'COALESCE(SUM(x), 0)' with 'SUM0(x)' wherever it occurs. + // Add 'SUM0(x)' to the aggregate call list, if necessary. + final java.util.List aggCallList = + new java.util.ArrayList<>(aggregate.getAggCallList()); + + final org.apache.calcite.rex.RexShuttle shuttle = new org.apache.calcite.rex.RexShuttle() { + @Override public org.apache.calcite.rex.RexNode visitCall(org.apache.calcite.rex.RexCall call) { + switch (call.getKind()) { + case CASE: + // Do we have "CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)"? 
+                        final java.util.List operands = call.operands;
+                        if (operands.size() == 3
+                            && operands.get(0).getKind() == org.apache.calcite.sql.SqlKind.IS_NOT_NULL
+                            && ((org.apache.calcite.rex.RexCall) operands.get(0)).operands.get(0).getKind()
+                            == org.apache.calcite.sql.SqlKind.INPUT_REF
+                            && operands.get(1).getKind() == org.apache.calcite.sql.SqlKind.CAST
+                            && ((org.apache.calcite.rex.RexCall) operands.get(1)).operands.get(0).getKind()
+                            == org.apache.calcite.sql.SqlKind.INPUT_REF
+                            && operands.get(2).getKind() == org.apache.calcite.sql.SqlKind.LITERAL) {
+                            final org.apache.calcite.rex.RexCall isNotNull = (org.apache.calcite.rex.RexCall) operands.get(0);
+                            final org.apache.calcite.rex.RexInputRef ref0 = (org.apache.calcite.rex.RexInputRef) isNotNull.operands.get(0);
+                            final org.apache.calcite.rex.RexCall cast = (org.apache.calcite.rex.RexCall) operands.get(1);
+                            final org.apache.calcite.rex.RexInputRef ref1 = (org.apache.calcite.rex.RexInputRef) cast.operands.get(0);
+                            if (ref0.getIndex() != ref1.getIndex()) {
+                                break; // the null test and the cast look at different columns
+                            }
+                            final int aggCallIndex = ref1.getIndex() - aggregate.getGroupCount();
+                            if (aggCallIndex < 0) {
+                                break; // the column is a group key, not an aggregate call
+                            }
+                            final org.apache.calcite.rel.core.AggregateCall aggCall = aggregate.getAggCallList().get(aggCallIndex);
+                            if (aggCall.getAggregation().getKind() != org.apache.calcite.sql.SqlKind.SUM) {
+                                break;
+                            }
+                            final org.apache.calcite.rex.RexLiteral literal = (org.apache.calcite.rex.RexLiteral) operands.get(2);
+                            if (java.util.Objects.equals(literal.getValueAs(java.math.BigDecimal.class), java.math.BigDecimal.ZERO)) {
+                                int j = findSum0(cluster.getTypeFactory(), aggCall, aggCallList);
+                                return cluster.getRexBuilder().makeInputRef(aggCallList.get(j).getType(), j); // NOTE(review): j indexes aggCallList and is not offset by the group count, while the mapping built afterwards treats reference indices as aggregate output positions - confirm
+                            }
+                        }
+                        break;
+                    default:
+                        break;
+                }
+                return super.visitCall(call); // no rewrite applied: fall back to the default traversal
+            }
+        };
+        // Apply the rewrite to every project expression, then record which aggregate output columns are still referenced.
+        final java.util.List projects2 = shuttle.visitList(project.getProjects());
+        final org.apache.calcite.util.ImmutableBitSet bits2 =
+            org.apache.calcite.plan.RelOptUtil.InputFinder.bits(projects2, null);
+
+
// Build the mapping that we will apply to the project expressions. + final org.apache.calcite.util.mapping.Mappings.TargetMapping mapping = + org.apache.calcite.util.mapping.Mappings.create( + org.apache.calcite.util.mapping.MappingType.FUNCTION, + aggregate.getGroupCount() + aggCallList.size(), -1); + int j = 0; + for (int i = 0; i < mapping.getSourceCount(); i++) { + if (i < aggregate.getGroupCount()) { + // Field is a group key. All group keys are retained. + mapping.set(i, j++); + } else if (bits2.get(i)) { + // Field is an aggregate call. It is used. + mapping.set(i, j++); + } else { + // Field is an aggregate call. It is not used. Remove it. + aggCallList.remove(j - aggregate.getGroupCount()); + } + } + + final RelBuilder builder = call.builder(); + builder.push(aggregate.getInput()); + builder.aggregate( + builder.groupKey(aggregate.getGroupSet(), aggregate.groupSets), aggCallList); + builder.project( + org.apache.calcite.rex.RexPermuteInputsShuttle.of(mapping).visitList(projects2), + project.getRowType().getFieldNames()); + builder.convert(project.getRowType(), true); + + return builder; + } + + private static int findSum0( + org.apache.calcite.rel.type.RelDataTypeFactory typeFactory, + org.apache.calcite.rel.core.AggregateCall sum, + java.util.List aggCallList) { + + final org.apache.calcite.rel.core.AggregateCall sum0 = + org.apache.calcite.rel.core.AggregateCall.create( + org.apache.calcite.sql.fun.SqlStdOperatorTable.SUM0, + sum.isDistinct(), + sum.isApproximate(), + false, // ignoreNulls + sum.getArgList(), // List + sum.filterArg, // int + sum.collation, // RelCollation + typeFactory.createTypeWithNullability(sum.type, false), // RelDataType + null); // String name + + final int i = aggCallList.indexOf(sum0); + if (i >= 0) { + return i; + } + aggCallList.add(sum0); + return aggCallList.size() - 1; + } + + private static int kindCount( + Iterable nodes, + final org.apache.calcite.sql.SqlKind kind) { + + final java.util.concurrent.atomic.AtomicInteger 
kindCount = + new java.util.concurrent.atomic.AtomicInteger(0); + + new org.apache.calcite.rex.RexVisitorImpl(true) { + @Override public Void visitCall(org.apache.calcite.rex.RexCall call) { + if (call.getKind() == kind) { + kindCount.incrementAndGet(); + } + return super.visitCall(call); + } + }.visitEach(nodes); + + return kindCount.get(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/AggregateExtractProjectTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateExtractProjectTest.java new file mode 100644 index 0000000..63d230f --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateExtractProjectTest.java @@ -0,0 +1,58 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class AggregateExtractProjectTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // EMPNO (index 0) + Tuple.of(RelType.fromString("VARCHAR", true), false), // ENAME (index 1) + Tuple.of(RelType.fromString("VARCHAR", true), false), // JOB (index 2) + Tuple.of(RelType.fromString("INTEGER", true), false), // MGR (index 3) + Tuple.of(RelType.fromString("DATE", true), false), // HIREDATE (index 4) + Tuple.of(RelType.fromString("DECIMAL", true), false), // SAL (index 5) + Tuple.of(RelType.fromString("DECIMAL", true), false), // COMM (index 6) + Tuple.of(RelType.fromString("INTEGER", true), false) // DEPTNO (index 7) + )); + builder.addTable(empTable); + + var empScan = builder.scan(empTable.getName()).build(); + + var before = builder + .push(empScan) + .aggregate( + builder.groupKey(builder.field(0), builder.field(7)), // Group by EMPNO, DEPTNO + builder.sum(builder.field(5)) // SUM(SAL) + ) + 
+ .build(); + + var after = builder + .push(empScan) + .project( + builder.field(0), // EMPNO -> X (position 0 in projection) + builder.field(5), // SAL -> Y (position 1 in projection) + builder.field(7) // DEPTNO -> Z (position 2 in projection) + ) + .aggregate( + builder.groupKey(builder.field(0), builder.field(2)), // Group by X, Z + builder.sum(builder.field(1)) // SUM(Y) + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.AggregateExtractProject.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running AggregateExtractProject comprehensive test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/AggregateFilterTransposeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateFilterTransposeTest.java new file mode 100644 index 0000000..c761c3a --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateFilterTransposeTest.java @@ -0,0 +1,59 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class AggregateFilterTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // ID (index 0) + Tuple.of(RelType.fromString("VARCHAR", true), false), // CATEGORY1 (index 1) + Tuple.of(RelType.fromString("VARCHAR", true), false), // CATEGORY2 (index 2) + Tuple.of(RelType.fromString("DECIMAL", true), false) // AMOUNT (index 3) + )); + builder.addTable(sourceTable); + + var sourceScan = builder.scan(sourceTable.getName()).build(); + + var before = builder + .push(sourceScan) + .filter(builder.call( + builder.genericPredicateOp("pred", true), 
+ builder.field(1), builder.field(2) + )) + .aggregate( + builder.groupKey(builder.field(1), builder.field(2)), + builder.sum(builder.field(3)) + ) + .build(); + + var after = builder + .push(sourceScan) + .aggregate( + builder.groupKey(builder.field(1), builder.field(2)), + builder.sum(builder.field(3)) + ) + .filter(builder.call( + builder.genericPredicateOp("pred", true), + builder.field(0), builder.field(1) + )) + .build(); + + var runner = CalciteTester.loadRules( + org.qed.Backends.Calcite.Generated.AggregateFilterTranspose.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running AggregateFilterTranspose comprehensive test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinJoinRemoveTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinJoinRemoveTest.java new file mode 100644 index 0000000..860384c --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinJoinRemoveTest.java @@ -0,0 +1,122 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class AggregateJoinJoinRemoveTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Simplified: each source table has a single INTEGER field + var sourceA1 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceA1); + + var sourceA2 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceA2); + + var sourceB1 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceB1); + + var sourceB2 = 
builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceB2); + + var sourceC1 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceC1); + + var sourceC2 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceC2); + + // Build scans + var scanA1 = builder.scan(sourceA1.getName()).build(); + var scanA2 = builder.scan(sourceA2.getName()).build(); + var scanB1 = builder.scan(sourceB1.getName()).build(); + var scanB2 = builder.scan(sourceB2.getName()).build(); + var scanC1 = builder.scan(sourceC1.getName()).build(); + var scanC2 = builder.scan(sourceC2.getName()).build(); + + // Before: ((scanA1 JOIN scanA2) LEFT JOIN (scanB1 JOIN scanB2)) LEFT JOIN (scanC1 JOIN scanC2) + // tblA: 2 fields [0,1] + // tblB: 2 fields [2,3] + // tblC: 2 fields [4,5] + // After first LEFT JOIN (bottomJoin): 4 fields [0,1,2,3] + // After second LEFT JOIN (topJoin): 6 fields [0,1,2,3,4,5] + // Aggregate on field(0) from tblA and field(4) from tblC (NOT field(2) from tblB!) 
+ var before = builder + .push(scanA1) + .push(scanA2) + .join(JoinRelType.INNER, builder.literal(true)) + .push(scanB1) + .push(scanB2) + .join(JoinRelType.INNER, builder.literal(true)) + .join(JoinRelType.LEFT, + builder.equals( + builder.field(2, 0, 0), // Left (tblA), field 0 + builder.field(2, 1, 0) // Right (tblB), field 0 + ) + ) + .push(scanC1) + .push(scanC2) + .join(JoinRelType.INNER, builder.literal(true)) + .join(JoinRelType.LEFT, + builder.equals( + builder.field(2, 0, 0), // Left (bottomJoin), field 0 (tblA.field0) + builder.field(2, 1, 0) // Right (tblC), field 0 + ) + ) + .aggregate( + builder.groupKey(builder.field(0), builder.field(4)) // tblA.f0 and tblC.f0 + ) + .build(); + + // After: (scanA1 JOIN scanA2) LEFT JOIN (scanC1 JOIN scanC2) + // tblA: 2 fields [0,1] + // tblC: 2 fields [2,3] + // After LEFT JOIN: 4 fields [0,1,2,3] + // Aggregate on field(0) from tblA and field(2) from tblC + var after = builder + .push(scanA1) + .push(scanA2) + .join(JoinRelType.INNER, builder.literal(true)) + .push(scanC1) + .push(scanC2) + .join(JoinRelType.INNER, builder.literal(true)) + .join(JoinRelType.LEFT, + builder.equals( + builder.field(2, 0, 0), // Left (tblA), field 0 + builder.field(2, 1, 0) // Right (tblC), field 0 + ) + ) + .aggregate( + builder.groupKey(builder.field(0), builder.field(2)) // tblA.f0 and tblC.f0 + ) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.AggregateJoinJoinRemove.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running AggregateJoinJoinRemove test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinRemoveTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinRemoveTest.java new file mode 100644 index 0000000..38ed45a --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateJoinRemoveTest.java 
@@ -0,0 +1,82 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class AggregateJoinRemoveTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Use INTEGER for every field to avoid VARCHAR charset issues + var sourceA1 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceA1); + + var sourceA2 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceA2); + + var sourceB1 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceB1); + + var sourceB2 = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceB2); + + // Build scans + var scanA1 = builder.scan(sourceA1.getName()).build(); + var scanA2 = builder.scan(sourceA2.getName()).build(); + var scanB1 = builder.scan(sourceB1.getName()).build(); + var scanB2 = builder.scan(sourceB2.getName()).build(); + + // Before: (scanA1 JOIN scanA2) LEFT JOIN (scanB1 JOIN scanB2), then aggregate + var before = builder + .push(scanA1) + .push(scanA2) + .join(JoinRelType.INNER, builder.literal(true)) + .push(scanB1) + .push(scanB2) + .join(JoinRelType.INNER, builder.literal(true)) + .join(JoinRelType.LEFT, + builder.equals( + builder.field(2, 0, 0), + builder.field(2, 1, 0) + ) + ) + .aggregate( + builder.groupKey(builder.field(0)) + ) + .build(); + + // After: just (scanA1 JOIN scanA2) then aggregate + var after = builder + .push(scanA1) + .push(scanA2) + .join(JoinRelType.INNER, builder.literal(true)) + .aggregate( + builder.groupKey(builder.field(0)) + ) + .build(); + + var runner = CalciteTester.loadRule( + 
org.qed.Backends.Calcite.Generated.AggregateJoinRemove.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running AggregateJoinRemove test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectConstantToDummyJoinTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectConstantToDummyJoinTest.java new file mode 100644 index 0000000..bbe9c59 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectConstantToDummyJoinTest.java @@ -0,0 +1,73 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class AggregateProjectConstantToDummyJoinTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create source table with 2 fields + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // field 0 + Tuple.of(RelType.fromString("INTEGER", true), false) // field 1 + )); + builder.addTable(sourceTable); + + var scan = builder.scan(sourceTable.getName()).build(); + + // Before: source -> project adds literals -> aggregate groups by literals + // Project: [field(0), literal(true), literal(2024), field(1)] + // Aggregate: group by [field(1)=true, field(2)=2024, field(0)], avg(field(3)) + var before = builder + .push(scan) + .project( + builder.field(0), + builder.literal(true), + builder.literal(2024), + builder.field(1) + ) + .aggregate( + builder.groupKey(builder.field(1), builder.field(2), builder.field(0)), + builder.avg(builder.field(3)) + ) + .build(); + + // After: source -> join with dummy values table -> project replaces literals with dummy fields -> 
aggregate + // Join adds dummy table with [true, 2024] + // After join: [source.f0, source.f1, dummy.col0=true, dummy.col1=2024] + // Project: [field(0)=source.f0, field(2)=dummy.true, field(3)=dummy.2024, field(1)=source.f1] + // Aggregate: group by [field(1)=dummy.true, field(2)=dummy.2024, field(0)=source.f0], avg(field(3)) + var after = builder + .push(scan) + .values(new String[]{"col0", "col1"}, true, 2024) + .join(JoinRelType.INNER, builder.literal(true)) + .project( + builder.field(0), // source.field(0) + builder.field(2), // dummy.col0 (true) + builder.field(3), // dummy.col1 (2024) + builder.field(1) // source.field(1) + ) + .aggregate( + builder.groupKey(builder.field(1), builder.field(2), builder.field(0)), + builder.avg(builder.field(3)) + ) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.AggregateProjectConstantToDummyJoin.Config.DEFAULT.toRule(), 1 + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running AggregateProjectConstantToDummyJoin test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectMergeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectMergeTest.java new file mode 100644 index 0000000..dd4dfcc --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/AggregateProjectMergeTest.java @@ -0,0 +1,72 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class AggregateProjectMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + 
Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("DATE", true), false), + Tuple.of(RelType.fromString("DECIMAL", true), false), + Tuple.of(RelType.fromString("DECIMAL", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(empTable); + + var empScan = builder.scan(empTable.getName()).build(); + + var before = builder + .push(empScan) + .project( + builder.field(7), + builder.field(0), + builder.field(5) + ) + .aggregate( + builder.groupKey(builder.field(0), builder.field(1)), + builder.sum(builder.field(2)) + ) + .project( + builder.field(0), + builder.field(2), + builder.field(1) + ) + .build(); + + var after = builder + .push(empScan) + .aggregate( + builder.groupKey(builder.field(0), builder.field(7)), + builder.sum(builder.field(5)) + ) + .project( + builder.field(1), + builder.field(0), + builder.field(2) + ) + .project( + builder.field(0), + builder.field(2), + builder.field(1) + ) + .build(); + + var runner = CalciteTester.loadRules(org.qed.Backends.Calcite.Generated.AggregateProjectMerge.Config.DEFAULT.toRule(), org.qed.Backends.Calcite.Generated.ProjectMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running AggregateProjectMerge comprehensive test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterAggregateTransposeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterAggregateTransposeTest.java new file mode 100644 index 0000000..4754cce --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterAggregateTransposeTest.java @@ -0,0 +1,58 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public 
class FilterAggregateTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("DECIMAL", true), false) + )); + builder.addTable(sourceTable); + + var sourceScan = builder.scan(sourceTable.getName()).build(); + + var before = builder + .push(sourceScan) + .aggregate( + builder.groupKey(builder.field(1), builder.field(2)), + builder.sum(builder.field(3)) + ) + .filter(builder.call( + builder.genericPredicateOp("pred", true), + builder.field(0), builder.field(1) + )) + .build(); + var after = builder + .push(sourceScan) + .filter(builder.call( + builder.genericPredicateOp("pred", true), + builder.field(1), builder.field(2) + )) + .aggregate( + builder.groupKey(builder.field(1), builder.field(2)), + builder.sum(builder.field(3)) + ) + .build(); + + var runner = CalciteTester.loadRules( + org.qed.Backends.Calcite.Generated.FilterAggregateTranspose.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterAggregateTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterIntoJoinTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterIntoJoinTest.java new file mode 100644 index 0000000..368013e --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterIntoJoinTest.java @@ -0,0 +1,41 @@ + +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import 
org.qed.RuleBuilder; + +public class FilterIntoJoinTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.FilterIntoJoin.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterIntoJoin test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterMergeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterMergeTest.java new file mode 100644 index 0000000..d449aa1 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterMergeTest.java @@ -0,0 +1,36 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + 
builder.addTable(table); + + var before = builder.scan(table.getName()) + .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) + .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("inner", true), builder.fields()), + builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.FilterMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterProjectTransposeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterProjectTransposeTest.java new file mode 100644 index 0000000..d4899fb --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterProjectTransposeTest.java @@ -0,0 +1,59 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterProjectTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan = builder.scan(table.getName()).build(); + + var before = builder + .push(scan) + .project( + builder.field(1), + builder.field(2) + ) + .filter( + builder.and( + builder.greaterThan(builder.field(0), builder.literal(50000)), + 
builder.equals(builder.field(1), builder.literal(5)) + ) + ) + .build(); + + var after = builder + .push(scan) + .filter( + builder.and( + builder.greaterThan(builder.field(1), builder.literal(50000)), + builder.equals(builder.field(2), builder.literal(5)) + ) + ) + .project( + builder.field(1), + builder.field(2) + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.FilterProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterProjectTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceFalseTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceFalseTest.java new file mode 100644 index 0000000..a9d5347 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceFalseTest.java @@ -0,0 +1,46 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterReduceFalseTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create a simple table + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceTable); + + var scan = builder.scan(sourceTable.getName()).build(); + + // Before: scan + filter(false) + var before = builder + .push(scan) + .filter(builder.literal(false)) + .build(); + + // After: scan + empty() + var after = builder + .push(scan) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.FilterReduceFalse.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public 
static void main(String[] args) { + System.out.println("Running FilterReduceFalse test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceTrueTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceTrueTest.java new file mode 100644 index 0000000..3ca0bdf --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterReduceTrueTest.java @@ -0,0 +1,45 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterReduceTrueTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create a simple table + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(sourceTable); + + var scan = builder.scan(sourceTable.getName()).build(); + + // Before: scan + filter(true) + var before = builder + .push(scan) + .filter(builder.literal(true)) + .build(); + + // After: just scan (filter is removed) + var after = builder + .push(scan) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.FilterReduceTrue.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterReduceTrue test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/FilterSetOpTransposeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/FilterSetOpTransposeTest.java new file mode 100644 index 0000000..6aae255 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/FilterSetOpTransposeTest.java @@ -0,0 +1,37 @@ +package org.qed.Backends.Calcite.Tests; + +import 
kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterSetOpTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + + var union = builder.push(scan1).push(scan2).union(false).build(); + var before = builder.push(union).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + + var filteredScan1 = builder.push(scan1).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var filteredScan2 = builder.push(scan2).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var after = builder.push(filteredScan1).push(filteredScan2).union(false).build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.FilterSetOpTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterSetOpTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/IntersectMergeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/IntersectMergeTest.java new file mode 100644 index 0000000..a12daf6 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/IntersectMergeTest.java @@ -0,0 +1,36 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class IntersectMergeTest { + + public static void runTest() { + var 
tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var firstIntersect = builder.push(scan1).push(scan2).intersect(false).build(); + var before = builder.push(firstIntersect).push(scan3).intersect(false).build(); + + var after = builder.push(scan1).push(scan2).push(scan3).intersect(false, 3).build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.IntersectMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running IntersectMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinAddRedundantSemiJoinTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/JoinAddRedundantSemiJoinTest.java new file mode 100644 index 0000000..122002c --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinAddRedundantSemiJoinTest.java @@ -0,0 +1,73 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinAddRedundantSemiJoinTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create left and right tables + var leftTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(leftTable); + + var rightTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), 
false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(rightTable); + + var scanLeft = builder.scan(leftTable.getName()).build(); + var scanRight = builder.scan(rightTable.getName()).build(); + + // Before: left INNER JOIN right ON left.f0 = right.f0 + var before = builder + .push(scanLeft) + .push(scanRight) + .join(JoinRelType.INNER, + builder.equals( + builder.field(2, 0, 0), // left.f0 + builder.field(2, 1, 0) // right.f0 + ) + ) + .build(); + + // After: (left SEMI JOIN right ON left.f0 = right.f0) INNER JOIN right ON left.f0 = right.f0 + var after = builder + .push(scanLeft) + .push(scanRight) + .join(JoinRelType.SEMI, + builder.equals( + builder.field(2, 0, 0), // left.f0 + builder.field(2, 1, 0) // right.f0 + ) + ) + .push(scanRight) + .join(JoinRelType.INNER, + builder.equals( + builder.field(2, 0, 0), // left.f0 (from semi join result) + builder.field(2, 1, 0) // right.f0 (from new right scan) + ) + ) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.JoinAddRedundantSemiJoin.Config.DEFAULT.toRule(), 1 + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinAddRedundantSemiJoin test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinCommuteTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/JoinCommuteTest.java new file mode 100644 index 0000000..a832a67 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinCommuteTest.java @@ -0,0 +1,77 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinCommuteTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + 
var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("DATE", true), false), + Tuple.of(RelType.fromString("DECIMAL", true), false), + Tuple.of(RelType.fromString("DECIMAL", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(empTable); + + var deptTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false) + )); + builder.addTable(deptTable); + + var empScan = builder.scan(empTable.getName()).build(); + var deptScan = builder.scan(deptTable.getName()).build(); + + var before = builder + .push(empScan) + .push(deptScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 0, 7), + builder.field(2, 1, 0) + )) + .build(); + + var after = builder + .push(deptScan) + .push(empScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 1, 7), + builder.field(2, 0, 0) + )) + .project( + builder.field(3), + builder.field(4), + builder.field(5), + builder.field(6), + builder.field(7), + builder.field(8), + builder.field(9), + builder.field(10), + builder.field(0), + builder.field(1), + builder.field(2) + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.JoinCommute.Config.DEFAULT.toRule(), 1); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinCommute test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinConditionPushTest.java 
b/src/main/java/org/qed/Backends/Calcite/Tests/JoinConditionPushTest.java new file mode 100644 index 0000000..ff2af1f --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinConditionPushTest.java @@ -0,0 +1,81 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinConditionPushTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(empTable); + + var deptTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("VARCHAR", true), false) + )); + builder.addTable(deptTable); + + var empScan = builder.scan(empTable.getName()).build(); + var deptScan = builder.scan(deptTable.getName()).build(); + + var before = builder + .push(empScan) + .push(deptScan) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("joinCond", true), + builder.field(2, 0, 2), + builder.field(2, 1, 0) + ), + builder.call(builder.genericPredicateOp("leftCond", true), + builder.field(2, 0, 0) + ), + builder.call(builder.genericPredicateOp("rightCond", true), + builder.field(2, 1, 0) + ) + )) + .build(); + + var after = builder + .push( + builder.push(empScan) + .filter(builder.call(builder.genericPredicateOp("leftCond", true), + builder.field(0) + )) + .build() + ) + .push( + builder.push(deptScan) + .filter(builder.call(builder.genericPredicateOp("rightCond", true), + builder.field(0) + )) + .build() + 
) + .join(JoinRelType.INNER, + builder.call(builder.genericPredicateOp("joinCond", true), + builder.field(2, 0, 2), + builder.field(2, 1, 0) + ) + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.JoinConditionPush.Config.DEFAULT.toRule(), 1); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinConditionPush test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinExtractFilterTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/JoinExtractFilterTest.java new file mode 100644 index 0000000..a2ec839 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinExtractFilterTest.java @@ -0,0 +1,46 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinExtractFilterTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("VARCHAR", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + var before = builder.push(leftScan) + .push(rightScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), + builder.field(2, 0, 0), builder.field(2, 1, 0))) + .build(); + + var trueJoin = builder.push(leftScan) + .push(rightScan) + .join(JoinRelType.INNER, builder.literal(true)) + .build(); + + var after = builder.push(trueJoin) + 
.filter(builder.call(builder.genericPredicateOp("join", true), builder.field(0), builder.field(1))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.JoinExtractFilter.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinExtractFilter test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinPushTransitivePredicatesTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/JoinPushTransitivePredicatesTest.java new file mode 100644 index 0000000..5d0c97c --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinPushTransitivePredicatesTest.java @@ -0,0 +1,42 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinPushTransitivePredicatesTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .scan(table.getName()).join(JoinRelType.INNER,builder.call(builder.genericPredicateOp("cond1", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("cond2", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("cond1", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("cond2", true), builder.joinFields()))) + 
.build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.JoinPushTransitivePredicates.Config.DEFAULT.toRule()); + + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinPushTransitivePredicates test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceFalseTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceFalseTest.java new file mode 100644 index 0000000..49e0fd1 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceFalseTest.java @@ -0,0 +1,42 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinReduceFalseTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var before = builder.scan(leftTable.getName()) + .scan(rightTable.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), + builder.literal(false))) + .build(); + + var after = builder.scan(leftTable.getName()) + .scan(rightTable.getName()) + .join(JoinRelType.INNER, builder.literal(false)) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.JoinReduceFalse.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running
JoinReduceFalse test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceTrueTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceTrueTest.java new file mode 100644 index 0000000..c4be30b --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/JoinReduceTrueTest.java @@ -0,0 +1,42 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinReduceTrueTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var before = builder.scan(leftTable.getName()) + .scan(rightTable.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), + builder.literal(true))) + .build(); + + var after = builder.scan(leftTable.getName()) + .scan(rightTable.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.JoinReduceTrue.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinReduceTrue test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/MinusMergeTest.java 
b/src/main/java/org/qed/Backends/Calcite/Tests/MinusMergeTest.java new file mode 100644 index 0000000..71ef13e --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/MinusMergeTest.java @@ -0,0 +1,34 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class MinusMergeTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var before = builder.push(scan1).push(scan2).minus(false, 2).push(scan3).minus(false, 2).build(); + + var union = builder.push(scan2).push(scan3).union(false).build(); + var after = builder.push(scan1).push(union).minus(false, 2).build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.MinusMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + public static void main(String[] args) { + System.out.println("Running MinusMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/ProjectAggregateMergeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/ProjectAggregateMergeTest.java new file mode 100644 index 0000000..9222d7f --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/ProjectAggregateMergeTest.java @@ -0,0 +1,74 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class ProjectAggregateMergeTest { + + public static void runTest() { + var 
tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create source table with 4 fields + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // field 0 - group key + Tuple.of(RelType.fromString("INTEGER", true), false), // field 1 - for sum/max + Tuple.of(RelType.fromString("INTEGER", true), false), // field 2 - for avg + Tuple.of(RelType.fromString("INTEGER", true), false) // field 3 - for count + )); + builder.addTable(sourceTable); + + var scan = builder.scan(sourceTable.getName()).build(); + + // Before: Aggregate with 4 agg calls, then project uses only 2 of them + // Aggregate: group by field(0), sum(field(1)), avg(field(2)), count(field(3)), max(field(1)) + // Result fields: [0=group, 1=sum, 2=avg, 3=count, 4=max] + // Project: uses only field(0), field(1), field(3) - skips avg and max + var before = builder + .push(scan) + .aggregate( + builder.groupKey(builder.field(0)), + builder.sum(false, "agg1", builder.field(1)), // Will be used - becomes field 1 + builder.avg(builder.field(2)), // Will be unused - field 2 + builder.count(false, "agg3", builder.field(3)), // Will be used - becomes field 3 + builder.max(builder.field(1)) // Will be unused - field 4 + ) + .project( + builder.field(0), // group key + builder.field(1), // agg1 (sum) + builder.field(3) // agg3 (count) - skips field 2 (avg) and field 4 (max) + ) + .build(); + + // After: Aggregate with only 2 used agg calls, project adjusted + // Aggregate: group by field(0), sum(field(1)), count(field(3)) + // Result fields: [0=group, 1=sum, 2=count] + // Project: field(0), field(1), field(2) + var after = builder + .push(scan) + .aggregate( + builder.groupKey(builder.field(0)), + builder.sum(false, "agg1", builder.field(1)), // field 1 + builder.count(false, "agg3", builder.field(3)) // field 2 (was 3) + ) + .project( + builder.field(0), // group key + builder.field(1), // agg1 (sum) + builder.field(2) // agg3 (count) - adjusted 
index + ) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.ProjectAggregateMerge.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectAggregateMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/ProjectFilterTransposeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/ProjectFilterTransposeTest.java new file mode 100644 index 0000000..2c8a522 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/ProjectFilterTransposeTest.java @@ -0,0 +1,59 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class ProjectFilterTransposeTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan = builder.scan(table.getName()).build(); + + var before = builder + .push(scan) + .filter( + builder.and( + builder.greaterThan(builder.field(1), builder.literal(50000)), + builder.equals(builder.field(2), builder.literal(5)) + ) + ) + .project( + builder.field(1), + builder.field(2) + ) + .build(); + + var after = builder + .push(scan) + .project( + builder.field(1), + builder.field(2) + ) + .filter( + builder.and( + builder.greaterThan(builder.field(0), builder.literal(50000)), + builder.equals(builder.field(1), builder.literal(5)) + ) + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.ProjectFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, 
before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectFilterTranspose test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/ProjectMergeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/ProjectMergeTest.java new file mode 100644 index 0000000..b2b293e --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/ProjectMergeTest.java @@ -0,0 +1,60 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class ProjectMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create a source table with 3 fields + var sourceTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // field 0 + Tuple.of(RelType.fromString("INTEGER", true), false), // field 1 + Tuple.of(RelType.fromString("INTEGER", true), false) // field 2 + )); + builder.addTable(sourceTable); + + var scan = builder.scan(sourceTable.getName()).build(); + + // Before: scan → project(f0, f1) → project(field(1), field(0)) + // Inner projection selects f0, f1 from source + // Outer projection reorders them to f1, f0 + var before = builder + .push(scan) + .project( + builder.field(0), // f0 + builder.field(1) // f1 + ) + .project( + builder.field(1), // selects field 1 from inner projection (which is f1) + builder.field(0) // selects field 0 from inner projection (which is f0) + ) + .build(); + + // After: scan → project(f1, f0) + // Single merged projection directly selects f1, f0 from source + var after = builder + .push(scan) + .project( + builder.field(1), // f1 directly from source + builder.field(0) // f0 directly from source + ) + .build(); + + var runner = CalciteTester.loadRule( + 
org.qed.Backends.Calcite.Generated.ProjectMerge.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyFilterTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyFilterTest.java new file mode 100644 index 0000000..88a78ae --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyFilterTest.java @@ -0,0 +1,46 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyFilterTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .filter( + builder.call( + builder.genericPredicateOp("filter_cond", true), + builder.fields() + ) + ) + .empty() + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.PruneEmptyFilter.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyFilter test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyIntersectTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyIntersectTest.java new file mode 100644 index 0000000..ba28a30 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyIntersectTest.java @@ -0,0 +1,58 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import 
kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyIntersectTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create two tables with compatible schemas + var tableA = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(tableA); + + var tableB = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(tableB); + + var scanA = builder.scan(tableA.getName()).build(); + var scanB = builder.scan(tableB.getName()).build(); + + // Before: A INTERSECT DISTINCT (Empty B) + var before = builder + .push(scanA) + .push(scanB) + .empty() + .intersect(false, 2) + .build(); + + // After: (Empty A) INTERSECT DISTINCT (Empty B) + var after = builder + .push(scanA) + .empty() + .push(scanB) + .empty() + .intersect(false, 2) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.PruneEmptyIntersect.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyIntersect test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyMinusTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyMinusTest.java new file mode 100644 index 0000000..c410414 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyMinusTest.java @@ -0,0 +1,51 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; +import org.apache.calcite.rel.RelNode; + 
+public class PruneEmptyMinusTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + RelNode scanA = builder + .scan(table.getName()) + .build(); + + RelNode scanB = builder + .scan(table.getName()) + .build(); + + RelNode before = builder + .push(scanA) + .push(scanB) + .minus(false) + .empty() + .build(); + + RelNode after = builder + .push(scanA) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.PruneEmptyMinus.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyMinus test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyProjectTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyProjectTest.java new file mode 100644 index 0000000..58db9f1 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyProjectTest.java @@ -0,0 +1,41 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyProjectTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .empty() + .project(builder.field(0)) + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.PruneEmptyProject.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + 
} + + public static void main(String[] args) { + System.out.println("Running PruneEmptyProject test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyUnionTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyUnionTest.java new file mode 100644 index 0000000..014e986 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/PruneEmptyUnionTest.java @@ -0,0 +1,62 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; +import org.apache.calcite.rel.RelNode; + +public class PruneEmptyUnionTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + RelNode scanA = builder + .scan(table.getName()) + .build(); + + RelNode emptyA = builder + .push(scanA) + .empty() + .build(); + + RelNode scanB = builder + .scan(table.getName()) + .build(); + + RelNode emptyB = builder + .push(scanB) + .empty() + .build(); + + RelNode before = builder + .push(scanA) + .push(scanB) + .union(false) + .empty() + .build(); + + RelNode after = builder + .push(emptyA) + .push(emptyB) + .union(false) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.PruneEmptyUnion.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyUnion test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/SemiJoinFilterTransposeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/SemiJoinFilterTransposeTest.java new file mode 100644 index 0000000..3d73f9c --- /dev/null +++
b/src/main/java/org/qed/Backends/Calcite/Tests/SemiJoinFilterTransposeTest.java @@ -0,0 +1,48 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class SemiJoinFilterTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + builder.push(leftScan); + builder.push(rightScan); + var joinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var semiJoin = builder.join(JoinRelType.SEMI, joinPredicate).build(); + builder.push(semiJoin); + var filterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var before = builder.filter(filterPredicate).build(); + + builder.push(leftScan); + var leftFilterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var filteredLeft = builder.filter(leftFilterPredicate).build(); + builder.push(filteredLeft); + builder.push(rightScan); + var afterJoinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var after = builder.join(JoinRelType.SEMI, afterJoinPredicate).build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.SemiJoinFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running SemiJoinFilterTranspose test..."); + 
runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/UnionMergeTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/UnionMergeTest.java new file mode 100644 index 0000000..68fdb9d --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/UnionMergeTest.java @@ -0,0 +1,36 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class UnionMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var firstUnion = builder.push(scan1).push(scan2).union(false).build(); + var before = builder.push(firstUnion).push(scan3).union(false).build(); + + var after = builder.push(scan1).push(scan2).push(scan3).union(false, 3).build(); + + var runner = CalciteTester.loadRule(org.qed.Backends.Calcite.Generated.UnionMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running UnionMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/UnionPullUpConstantsTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/UnionPullUpConstantsTest.java new file mode 100644 index 0000000..df87db6 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/UnionPullUpConstantsTest.java @@ -0,0 +1,87 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; 
+import org.qed.RuleBuilder; + +public class UnionPullUpConstantsTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create source tables with 3 fields each + var leftTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(leftTable); + + var rightTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(rightTable); + + var scanLeft = builder.scan(leftTable.getName()).build(); + var scanRight = builder.scan(rightTable.getName()).build(); + + // Before: both sides project with constant in middle + // Left: [field(0), literal("ACTIVE"), field(2)] + // Right: [field(0), literal("ACTIVE"), field(2)] + // Union: all 3 columns including the constant + var before = builder + .push(scanLeft) + .project( + builder.field(0), + builder.alias(builder.literal("ACTIVE"), "status"), + builder.field(2) + ) + .push(scanRight) + .project( + builder.field(0), + builder.alias(builder.literal("ACTIVE"), "status"), + builder.field(2) + ) + .union(true, 2) + .build(); + + // After: pull constant to top + // Left reduced: [field(0), field(2)] + // Right reduced: [field(0), field(2)] + // Union: only 2 columns + // Top project adds constant back: [field(0), literal("ACTIVE"), field(1)] + var after = builder + .push(scanLeft) + .project( + builder.field(0), + builder.field(2) + ) + .push(scanRight) + .project( + builder.field(0), + builder.field(2) + ) + .union(true, 2) + .project( + builder.field(0), + builder.alias(builder.literal("ACTIVE"), "status"), + builder.field(1) + ) + .build(); + + var runner = CalciteTester.loadRules( + 
org.qed.Backends.Calcite.Generated.UnionPullUpConstants.Config.DEFAULT.toRule(), org.qed.Backends.Calcite.Generated.ProjectMerge.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running UnionPullUpConstants test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Calcite/Tests/UnionToDistinctTest.java b/src/main/java/org/qed/Backends/Calcite/Tests/UnionToDistinctTest.java new file mode 100644 index 0000000..bc5399b --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/Tests/UnionToDistinctTest.java @@ -0,0 +1,58 @@ +package org.qed.Backends.Calcite.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Backends.Calcite.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class UnionToDistinctTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create left and right tables with same schema + var leftTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(leftTable); + + var rightTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(rightTable); + + var scanLeft = builder.scan(leftTable.getName()).build(); + var scanRight = builder.scan(rightTable.getName()).build(); + + // Before: left UNION right (UNION DISTINCT, all=false) + var before = builder + .push(scanLeft) + .push(scanRight) + .union(false, 2) // UNION DISTINCT + .build(); + + // After: (left UNION ALL right) then aggregate (group by all fields) + var after = builder + .push(scanLeft) + .push(scanRight) + .union(true, 2) // UNION ALL + .aggregate( + builder.groupKey(builder.field(0), builder.field(1)) + ) + 
.build(); + + var runner = CalciteTester.loadRule( + org.qed.Backends.Calcite.Generated.UnionToDistinct.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running UnionToDistinct test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java b/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java new file mode 100644 index 0000000..276387d --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java @@ -0,0 +1,1657 @@ +package org.qed.Backends.Cockroach; +import kala.collection.Seq; +import kala.collection.immutable.ImmutableMap; +import org.qed.CodeGenerator; +import org.qed.RelRN; +import org.qed.RexRN; +import java.util.concurrent.atomic.AtomicInteger; +public class CockroachGenerator implements CodeGenerator { + private static String b(String var) { + return "$" + var + ":*"; + } + private static String r(String var) { + return "$" + var; + } + private static String N(String type, String... 
children) { + StringBuilder sb = new StringBuilder("(").append(type); + for (String c : children) sb.append("\n ").append(c); + return sb.append("\n)").toString(); + } + private static String filtersBoundBy(String condVar, String ref) { + return "(FiltersBoundBy " + r(condVar) + " " + ref + ")"; + } + private static boolean flag(Env env, String key) { + return env.bindings().containsKey(key); + } + private static String get(Env env, String key) { + return env.bindings().get(key); + } + private static String get(Env env, String key, String def) { + return env.bindings().getOrDefault(key, def); + } + private static String joinType(org.apache.calcite.rel.core.JoinRelType ty) { + return switch (ty) { + case INNER -> "InnerJoin"; + case LEFT -> "LeftJoin"; + case RIGHT -> "RightJoin"; + case FULL -> "FullJoin"; + case SEMI -> "SemiJoin"; + case ANTI -> "AntiJoin"; + default -> "InnerJoin"; + }; + } + private String getJoinType(org.apache.calcite.rel.core.JoinRelType ty) { + return joinType(ty); + } + @Override + public Env preMatch(String rulename) { + return Env.empty(rulename); + } + @Override + public Env preTransform(Env env) { + String p = env.pattern(); + if (p == null) return env; + int idx = p.indexOf("(HasZeroRows $"); + if (idx < 0) return env; + int start = idx + "(HasZeroRows $".length(); + int end = p.indexOf(")", start); + if (end <= start) return env; + String var = p.substring(start, end).trim(); + if (var.isEmpty()) return env; + return env.addBinding("hasZeroRows", "true").addBinding("zeroInput", var); + } + @Override + public Env onMatchScan(Env env, RelRN.Scan scan) { + String var = env.generateVar("input"); + return env.addBinding(scan.name(), var).focus(b(var)); + } + @Override + public Env onMatchFilter(Env env, RelRN.Filter filter) { + if (filter.source() instanceof RelRN.Project) { + return matchFilterOverProject(env); + } + if (filter.source() instanceof RelRN.Union) { + return matchFilterOverUnion(env); + } + Env sourceEnv = onMatch(env, 
filter.source()); + String sourcePattern = sourceEnv.current(); + Env condEnv = onMatch(sourceEnv, filter.cond()); + if (filter.source() instanceof RelRN.Empty) { + String inputVar = condEnv.generateVar("input"); + String filtersVar = condEnv.generateVar("filters"); + String pattern = N("Select", b(inputVar) + " & (HasZeroRows " + r(inputVar) + ")", b(filtersVar) ); + return condEnv .addBinding("isPruneEmptyFilter", "true") .addBinding("pruneEmptyInput", inputVar) .setPattern(pattern).focus(pattern); + } + if (filter.cond() instanceof RexRN.True) { + String pattern = N("Select", sourcePattern, "[]"); + return condEnv.setPattern(pattern).focus(pattern); + } + if (filter.cond() instanceof RexRN.False) { + String onVar = condEnv.generateVar("on"); + Env onEnv = condEnv.addBinding("on", onVar); + String itemVar = onEnv.generateVar("item"); + Env itemEnv = onEnv.addBinding("item", itemVar); + String listPat = r(onVar) + ":[\n" + " ...\n" + " " + r(itemVar) + ":(FiltersItem (False))\n" + " ...\n" + " ]"; + String pattern = N("Select", sourcePattern, listPat); + return itemEnv.setPattern(pattern).focus(pattern); + } + String condPattern = condEnv.current(); + if (filter.source() instanceof RelRN.Aggregate) { + String condVarName = extractVar(condPattern); + String privateVar = get(sourceEnv, "aggregate_private", null); + if (condVarName != null && privateVar != null) { + condPattern = condPattern + " & " + filtersBoundBy(condVarName, "(GroupingCols " + r(privateVar) + ")"); + } + } + String filterCondVarName = extractVar(condEnv.current()); + if (filterCondVarName != null) { + condEnv = condEnv.addBinding("filterCondVar", filterCondVarName); + } + String pattern = N("Select", sourcePattern, condPattern); + return condEnv.setPattern(pattern).focus(pattern); + } + private Env matchFilterOverProject(Env env) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String projVar = inputEnv.generateVar("proj"); + Env projEnv = 
inputEnv.addBinding("proj", projVar); + String passthroughVar = projEnv.generateVar("passthrough"); + Env passEnv = projEnv.addBinding("passthrough", passthroughVar); + String condVar = passEnv.generateVar("cond"); + Env condEnv = passEnv.addBinding("cond", condVar); + String inputColsVar = condEnv.generateVar("inputCols"); + Env resultEnv = condEnv.addBinding("inputCols", inputColsVar); + String projectPat = "(Project\n" + " " + b(inputVar) + "\n" + " " + b(projVar) + "\n" + " " + b(passthroughVar) + "\n" + ")"; + String condLine = b(condVar) + " &\n" + " " + filtersBoundBy(condVar, r(inputColsVar) + ":(OutputCols " + r(inputVar) + ")"); + String pattern = N("Select", projectPat, condLine); + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchFilterOverUnion(Env env) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String leftVar = inputEnv.generateVar("left"); + Env leftEnv = inputEnv.addBinding("left", leftVar); + String rightVar = leftEnv.generateVar("right"); + Env rightEnv = leftEnv.addBinding("right", rightVar); + String colmapVar = rightEnv.generateVar("colmap"); + Env colmapEnv = rightEnv.addBinding("colmap", colmapVar); + String filterVar = colmapEnv.generateVar("filter"); + Env filterEnv = colmapEnv.addBinding("filter", filterVar); + String itemVar = filterEnv.generateVar("item"); + Env itemEnv = filterEnv.addBinding("item", itemVar); + String inputColsVar = itemEnv.generateVar("inputCols"); + Env resultEnv = itemEnv.addBinding("inputCols", inputColsVar) .addBinding("isFilterSetOpTranspose", "true"); + String unionPat = r(inputVar) + ":(Union " + r(leftVar) + ":* " + r(rightVar) + ":* " + r(colmapVar) + ":*)"; + String listPat = r(filterVar) + ":[\n" + " ...\n" + " " + r(itemVar) + ":* &\n" + " (CanMapOnSetOp " + r(itemVar) + ") &\n" + " (IsBoundBy " + r(itemVar) + " " + r(inputColsVar) + ":(OutputCols " + r(inputVar) + "))\n" + " ...\n" + " ]"; + String pattern = N("Select", 
unionPat, listPat); + return resultEnv.setPattern(pattern).focus(pattern); + } + public Env onMatchProject(Env env, RelRN.Project project) { + if (project.source() instanceof RelRN.Empty) { + return matchProjectOverEmpty(env); + } + if (project.source() instanceof RelRN.Aggregate aggregate) { + return matchProjectOverAggregate(env, aggregate); + } + if (project.source() instanceof RelRN.Project) { + return matchProjectOverProject(env, project); + } + return matchProjectGeneral(env, project); + } + private Env matchProjectOverEmpty(Env env) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("zeroInput", inputVar) .addBinding("hasZeroRows", "true"); + String projectionsVar = inputEnv.generateVar("projections"); + Env projEnv = inputEnv.addBinding("projections", projectionsVar); + String passthroughVar = projEnv.generateVar("passthrough"); + Env passEnv = projEnv.addBinding("passthrough", passthroughVar); + String pattern = N("Project", b(inputVar) + " & (HasZeroRows " + r(inputVar) + ")", b(projectionsVar), b(passthroughVar) ); + return passEnv.setPattern(pattern).focus(pattern); + } + private Env matchProjectOverAggregate(Env env, RelRN.Aggregate aggregate) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + Env aggInputEnv = onMatch(inputEnv, aggregate.source()); + String innerInputVar = aggInputEnv.current().replace("$", "").replace(":*", "").trim(); + Env aggBound = aggInputEnv.addBinding("innerInput", innerInputVar); + Env aggsEnv = onMatchAggCalls(aggBound, aggregate.aggCalls()); + String aggregationsVar = aggsEnv.generateVar("aggregations"); + Env aggsBindEnv = aggsEnv.addBinding("aggregations", aggregationsVar); + Env groupEnv = onMatchGroupSet(aggsBindEnv, aggregate.groupSet()); + String groupingPrivateVar = groupEnv.generateVar("groupingPrivate"); + Env gpBindEnv = groupEnv.addBinding("groupingPrivate", groupingPrivateVar); + Env projEnv = onMatch(gpBindEnv, aggregate.source()); 
+ String projectionsVar = projEnv.generateVar("projections"); + Env projBindEnv = projEnv.addBinding("projections", projectionsVar); + String passthroughVar = projBindEnv.generateVar("passthrough"); + Env passEnv = projBindEnv.addBinding("passthrough", passthroughVar); + String neededVar = passEnv.generateVar("needed"); + Env resultEnv = passEnv.addBinding("needed", neededVar); + String groupByPat = "(GroupBy\n" + " " + b(innerInputVar) + "\n" + " " + b(aggregationsVar) + "\n" + " " + b(groupingPrivateVar) + "\n" + " )"; + String passCond = b(passthroughVar) + " &\n" + " (CanPruneAggCols\n" + " " + r(aggregationsVar) + "\n" + " " + r(neededVar) + ":(UnionCols\n" + " (ProjectionOuterCols " + r(projectionsVar) + ")\n" + " " + r(passthroughVar) + "\n" + " )\n" + " )"; + String pattern = "(Project\n" + " " + r(inputVar) + ":(" + groupByPat.trim() + "\n" + " " + b(projectionsVar) + "\n" + " " + passCond + "\n" + ")"; + pattern = "(Project\n" + " " + r(inputVar) + ":(GroupBy\n" + " " + b(innerInputVar) + "\n" + " " + b(aggregationsVar) + "\n" + " " + b(groupingPrivateVar) + "\n" + " )\n" + " " + b(projectionsVar) + "\n" + " " + b(passthroughVar) + " &\n" + " (CanPruneAggCols\n" + " " + r(aggregationsVar) + "\n" + " " + r(neededVar) + ":(UnionCols\n" + " (ProjectionOuterCols " + r(projectionsVar) + ")\n" + " " + r(passthroughVar) + "\n" + " )\n" + " )\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchProjectOverProject(Env env, RelRN.Project project) { + RelRN.Project innerProject = (RelRN.Project) project.source(); + Env outerProjEnv = onMatch(env, project.map()); + String outerProjPat = outerProjEnv.current(); + Env innerInputEnv = onMatch(outerProjEnv, innerProject.source()); + String innerInputPat = innerInputEnv.current(); + Env innerProjEnv = onMatch(innerInputEnv, innerProject.map()); + String innerProjPat = innerProjEnv.current(); + String innerPassVar = innerProjEnv.generateVar("innerPassthrough"); + Env innerPassEnv = 
innerProjEnv.addBinding("innerPassthrough", innerPassVar); + String outerPassVar = innerPassEnv.generateVar("passthrough"); + Env resultEnv = innerPassEnv.addBinding("passthrough", outerPassVar); + String outerProjRef = outerProjPat.replace(":*", ""); + String innerProjRef = innerProjPat.replace(":*", ""); + String innerProjBlock = "Project\n " + innerInputPat + "\n " + innerProjPat + "\n " + b(innerPassVar); + String condLine = outerProjPat + " &\n" + " (CanMergeProjections " + outerProjRef + " " + innerProjRef + ")"; + String pattern = "(Project\n" + " $input:(" + innerProjBlock + ")\n" + " " + condLine + "\n" + " " + b(outerPassVar) + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchProjectGeneral(Env env, RelRN.Project project) { + Env sourceEnv = onMatch(env, project.source()); + String sourcePat = sourceEnv.current(); + Env projEnv = onMatch(sourceEnv, project.map()); + String projPat = projEnv.current(); + String passthroughVar = projEnv.generateVar("passthrough"); + Env passEnv = projEnv.addBinding("passthrough", passthroughVar); + String passthroughCond = ""; + if (project.source() instanceof RelRN.Filter) { + String condVarName = get(sourceEnv, "filterCondVar", null); + if (condVarName != null) { + passthroughCond = " & " + filtersBoundBy(condVarName, r(passthroughVar)); + } + } + String pattern = N("Project", sourcePat, projPat, b(passthroughVar) + passthroughCond ); + return passEnv.setPattern(pattern).focus(pattern); + } + @Override + public Env onMatchJoin(Env env, RelRN.Join join) { + if (join.cond() instanceof RexRN.And and && and.sources().size() == 2) { + boolean hasTrue = and.sources().stream().anyMatch(s -> s instanceof RexRN.True); + boolean hasFalse = and.sources().stream().anyMatch(s -> s instanceof RexRN.False); + RexRN other = and.sources().stream() .filter(s -> !(s instanceof RexRN.True) && !(s instanceof RexRN.False)) .findFirst().orElse(null); + if (hasTrue && other != null) { + return 
matchJoinReduceTrue(env, join); + } + if (hasFalse && other != null) { + return matchJoinReduceFalse(env, join); + } + } + Env leftEnv = onMatch(env, join.left()); + String leftPat = leftEnv.current(); + Env rightEnv = onMatch(leftEnv, join.right()); + String rightPat = rightEnv.current(); + Env condEnv = onMatch(rightEnv, join.cond()); + String condPat = condEnv.current(); + String privateVar = condEnv.generateVar("private"); + Env privateEnv = condEnv .addBinding("private_" + System.identityHashCode(join), privateVar) .addBinding("last_private", privateVar); + if (join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER) { + String leftVar = privateEnv.generateVar("left"); + String rightVar = privateEnv.generateVar("right"); + String onVar = privateEnv.generateVar("on"); + Env bound = privateEnv .addBinding("left", leftVar) .addBinding("right", rightVar) .addBinding("on", onVar) .addBinding("private", privateVar); + if (env.rulename.equals("JoinAddRedundantSemiJoin")) { + return matchJoinAddRedundantSemiJoin(bound, privateVar, leftVar, rightVar); + } + if (env.rulename.equals("JoinCommute")) { + return matchJoinCommute(bound, privateVar, leftVar, rightVar, onVar); + } + } + if (env.rulename.equals("JoinExtractFilter") && join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER && !(join.cond() instanceof RexRN.And)) { + return matchJoinExtractFilter(privateEnv, join, privateVar); + } + if (env.rulename.equals("JoinPushTransitivePredicates") && join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER) { + String condVarName = extractVar(condPat); + if (condVarName != null) { + condPat = condPat + " & ^(IsFilterEmpty " + r(condVarName) + ")"; + } + } + String jType = joinType(join.ty().semantics()); + String pattern = N(jType, leftPat, rightPat, condPat, b(privateVar)); + return privateEnv.setPattern(pattern).focus(pattern); + } + private Env matchJoinReduceTrue(Env env, RelRN.Join join) { + Env leftEnv = onMatch(env, 
join.left()); + Env rightEnv = onMatch(leftEnv, join.right()); + String onVar = rightEnv.generateVar("on"); + Env onEnv = rightEnv.addBinding("on", onVar); + String itemVar = onEnv.generateVar("item"); + Env itemEnv = onEnv.addBinding("item", itemVar); + String privateVar = itemEnv.generateVar("private"); + Env resultEnv = itemEnv .addBinding("private_" + System.identityHashCode(join), privateVar) .addBinding("last_private", privateVar) .addBinding("joinReduceTrue", "true"); + String listPat = r(onVar) + ":[\n" + " ...\n" + " " + r(itemVar) + ":(FiltersItem (True))\n" + " ...\n" + " ]"; + String pattern = N(joinType(join.ty().semantics()), leftEnv.current(), rightEnv.current(), listPat, b(privateVar)); + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchJoinReduceFalse(Env env, RelRN.Join join) { + Env leftEnv = onMatch(env, join.left()); + Env rightEnv = onMatch(leftEnv, join.right()); + String onVar = rightEnv.generateVar("on"); + Env onEnv = rightEnv.addBinding("on", onVar); + String itemVar = onEnv.generateVar("item"); + Env itemEnv = onEnv.addBinding("item", itemVar); + String privateVar = itemEnv.generateVar("private"); + Env resultEnv = itemEnv .addBinding("private_" + System.identityHashCode(join), privateVar) .addBinding("last_private", privateVar) .addBinding("joinReduceFalse", "true"); + String listPat = r(onVar) + ":[\n" + " ...\n" + " " + r(itemVar) + ":(FiltersItem\n" + " (And * (False))\n" + " )\n" + " ...\n" + " ] &\n" + " ^(IsFilterFalse " + r(onVar) + ")"; + String pattern = N(joinType(join.ty().semantics()), leftEnv.current(), rightEnv.current(), listPat, b(privateVar)); + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchJoinAddRedundantSemiJoin(Env bound, String privateVar, String leftVar, String rightVar) { + bound = bound.addBinding("isJoinAddRedundantSemiJoin", "true"); + String filtersVar = bound.generateVar("filters"); + Env resultEnv = bound.addBinding("filters", filtersVar); + String 
privLine = b(privateVar) + " & ^(IsRedundantSemiJoin " + r(leftVar) + " " + r(rightVar) + " " + r(filtersVar) + ")"; + String pattern = N("InnerJoin", r(leftVar) + ":^(Values)", b(rightVar), b(filtersVar), privLine ); + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchJoinCommute(Env bound, String privateVar, String leftVar, String rightVar, String onVar) { + bound = bound.addBinding("isJoinCommute", "true"); + String privLine = b(privateVar) + " &\n" + " (CanCommuteJoin " + r(leftVar) + " " + r(rightVar) + ")"; + String pattern = N("InnerJoin", b(leftVar), b(rightVar), b(onVar), privLine); + return bound.setPattern(pattern).focus(pattern); + } + private Env matchJoinExtractFilter(Env privateEnv, RelRN.Join join, String privateVar) { + String leftVar = privateEnv.generateVar("left"); + String rightVar = privateEnv.generateVar("right"); + String onVar = privateEnv.generateVar("on"); + Env bound = privateEnv .addBinding("left", leftVar) .addBinding("right", rightVar) .addBinding("on", onVar) .addBinding("private", privateVar) .addBinding("isJoinExtractFilter", "true"); + if (join.cond() instanceof RexRN.Pred pred) { + bound = bound.addBinding(pred.operator().getName(), onVar); + } + String condLine = b(onVar) + " &\n" + " (CanExtractJoinFilter " + r(leftVar) + " " + r(rightVar) + " " + r(onVar) + ")"; + String pattern = N(joinType(join.ty().semantics()), b(leftVar), b(rightVar), condLine, b(privateVar)); + return bound.setPattern(pattern).focus(pattern); + } + @Override + public Env onMatchJoinWithSeparateConds(Env env, RelRN.JoinWithSeparateConds join) { + if (join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER && join.cond() instanceof RexRN.And and && and.sources().size() > 2) { + String leftVar = env.generateVar("left"); + Env leftEnv = env.addBinding("left", leftVar); + String rightVar = leftEnv.generateVar("right"); + Env rightEnv = leftEnv.addBinding("right", rightVar); + String onVar = rightEnv.generateVar("on"); 
+ Env onEnv = rightEnv.addBinding("on", onVar); + String privateVar = onEnv.generateVar("private"); + Env resultEnv = onEnv .addBinding("private", privateVar) .addBinding("isJoinConditionPush", "true"); + String leftLine = b(leftVar) + " & ^(HasOuterCols " + r(leftVar) + ")"; + String rightLine = b(rightVar) + " & ^(HasOuterCols " + r(rightVar) + ")"; + String onLine = b(onVar) + " &\n" + " (HasBoundConditions\n" + " " + r(onVar) + "\n" + " (OutputCols " + r(leftVar) + ")\n" + " (OutputCols " + r(rightVar) + ")\n" + " )"; + String pattern = N("InnerJoin", leftLine, rightLine, onLine, b(privateVar)); + return resultEnv.setPattern(pattern).focus(pattern); + } + Env leftEnv = onMatch(env, join.left()); + Env rightEnv = onMatch(leftEnv, join.right()); + Env condEnv = onMatch(rightEnv, join.cond()); + String privateVar = condEnv.generateVar("private"); + Env resultEnv = condEnv .addBinding("private_" + System.identityHashCode(join), privateVar) .addBinding("last_private", privateVar); + String jType = joinType(join.ty().semantics()); + String pattern = N(jType, leftEnv.current(), rightEnv.current(), condEnv.current(), b(privateVar)); + return resultEnv.setPattern(pattern).focus(pattern); + } + @Override + public Env onMatchUnion(Env env, RelRN.Union union) { + if (union.sources().size() == 2) { + RelRN l = union.sources().get(0), r = union.sources().get(1); + if (l instanceof RelRN.Empty && r instanceof RelRN.Empty) { + return matchUnionBothEmpty(env, union); + } + } + if (union.sources().size() == 2 && union.sources().get(0) instanceof RelRN.Union inner && inner.sources().size() == 2) { + return matchNestedUnion(env, union); + } + if (union.sources().size() == 2 && !union.all()) { + RelRN ls = union.sources().get(0), rs = union.sources().get(1); + if (!(ls instanceof RelRN.Union)) { + return matchDistinctUnion(env, union); + } + } + if (union.all() && union.sources().size() == 2) { + RelRN ls = union.sources().get(0), rs = union.sources().get(1); + RelRN.Project 
leftProj = toProjectIfPossible(ls); + RelRN.Project rightProj = toProjectIfPossible(rs); + if (leftProj != null && rightProj != null) { + return matchUnionPullUpConstants(env, union, leftProj, rightProj); + } + } + return matchUnionGeneral(env, union); + } + private static RelRN.Project toProjectIfPossible(RelRN node) { + if (node instanceof RelRN.Project p) return p; + if (node instanceof org.qed.RRuleInstances.UnionPullUpConstants.LeftProjectionWithConstants lp) + return new RelRN.Project(lp.input().field(0), lp.input()); + if (node instanceof org.qed.RRuleInstances.UnionPullUpConstants.RightProjectionWithConstants rp) + return new RelRN.Project(rp.input().field(0), rp.input()); + return null; + } + private Env matchUnionBothEmpty(Env env, RelRN.Union union) { + String leftVar = env.generateVar("left"); + Env leftEnv = env.addBinding("left", leftVar); + String rightVar = leftEnv.generateVar("right"); + Env rightEnv = leftEnv.addBinding("right", rightVar); + String privateVar = rightEnv.generateVar("private"); + Env privEnv = rightEnv.addBinding("private", privateVar); + String outColsVar = privEnv.generateVar("outCols"); + Env resultEnv = privEnv.addBinding("outCols", outColsVar) .addBinding("hasZeroRows", "true"); + String uType = union.all() ? 
"UnionAll" : "Union"; + String leftLine = b(leftVar) + " & (HasZeroRows " + r(leftVar) + ")"; + String rightLine = b(rightVar) + " & (HasZeroRows " + r(rightVar) + ")"; + String privLine = r(privateVar) + ":(SetPrivate * * " + r(outColsVar) + ":*)"; + String pattern = N(uType, leftLine, rightLine, privLine); + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchNestedUnion(Env env, RelRN.Union union) { + String leftLeftVar = env.generateVar("leftLeft"); + Env llEnv = env.addBinding("leftLeft", leftLeftVar); + String leftRightVar = llEnv.generateVar("leftRight"); + Env lrEnv = llEnv.addBinding("leftRight", leftRightVar); + String innerPrivateVar = lrEnv.generateVar("innerPrivate"); + Env ipEnv = lrEnv.addBinding("innerPrivate", innerPrivateVar); + String innerLeftCols = ipEnv.generateVar("innerLeftCols"); + Env ilcEnv = ipEnv.addBinding("innerLeftCols", innerLeftCols); + String innerRightCols = ilcEnv.generateVar("innerRightCols"); + Env ircEnv = ilcEnv.addBinding("innerRightCols", innerRightCols); + String innerOutCols = ircEnv.generateVar("innerOutCols"); + Env iocEnv = ircEnv.addBinding("innerOutCols", innerOutCols); + String leftVar = iocEnv.generateVar("left"); + Env lEnv = iocEnv.addBinding("left", leftVar); + String rightVar = lEnv.generateVar("right"); + Env rEnv = lEnv.addBinding("right", rightVar); + String outerPrivate = rEnv.generateVar("outerPrivate"); + Env opEnv = rEnv.addBinding("outerPrivate", outerPrivate); + String outerRightCols = opEnv.generateVar("outerRightCols"); + Env orcEnv = opEnv.addBinding("outerRightCols", outerRightCols); + String outerOutCols = orcEnv.generateVar("outerOutCols"); + Env resultEnv = orcEnv .addBinding("outerOutCols", outerOutCols) .addBinding("isUnionMerge", "true"); + String uType = union.all() ? 
"UnionAll" : "Union"; + String innerSetPriv = r(innerPrivateVar) + ":(SetPrivate " + b(innerLeftCols) + " " + b(innerRightCols) + " " + b(innerOutCols) + ")"; + String innerUnion = "(" + uType + "\n" + " " + b(leftLeftVar) + "\n" + " " + b(leftRightVar) + "\n" + " " + innerSetPriv + "\n" + " )"; + String outerSetPriv = r(outerPrivate) + ":(SetPrivate * " + b(outerRightCols) + " " + b(outerOutCols) + ")"; + String pattern = N(uType, r(leftVar) + ":(" + innerUnion.trim(), b(rightVar), outerSetPriv ); + pattern = "(" + uType + "\n" + " " + r(leftVar) + ":(" + uType + "\n" + " " + b(leftLeftVar) + "\n" + " " + b(leftRightVar) + "\n" + " " + r(innerPrivateVar) + ":(SetPrivate " + b(innerLeftCols) + " " + b(innerRightCols) + " " + b(innerOutCols) + ")\n" + " )\n" + " " + b(rightVar) + "\n" + " " + r(outerPrivate) + ":(SetPrivate * " + b(outerRightCols) + " " + b(outerOutCols) + ")\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchDistinctUnion(Env env, RelRN.Union union) { + Env leftEnv = onMatch(env, union.sources().get(0)); + String leftVar = leftEnv.generateVar("left"); + Env lBound = leftEnv.addBinding("left", leftVar); + Env rightEnv = onMatch(lBound, union.sources().get(1)); + String rightVar = rightEnv.generateVar("right"); + Env rBound = rightEnv.addBinding("right", rightVar); + String privateVar = rBound.generateVar("private"); + Env privEnv = rBound.addBinding("private", privateVar); + String leftCols = privEnv.generateVar("leftCols"); + Env lcEnv = privEnv.addBinding("leftCols", leftCols); + String rightCols = lcEnv.generateVar("rightCols"); + Env rcEnv = lcEnv.addBinding("rightCols", rightCols); + String outCols = rcEnv.generateVar("outCols"); + Env ocEnv = rcEnv.addBinding("outCols", outCols); + String keyColsVar = ocEnv.generateVar("keyCols"); + Env kcEnv = ocEnv.addBinding("keyCols", keyColsVar); + String okVar = kcEnv.generateVar("ok"); + Env resultEnv = kcEnv.addBinding("ok", okVar); + String setPriv = r(privateVar) 
+ ":(SetPrivate " + b(leftCols) + " " + b(rightCols) + " " + b(outCols) + ") &\n" + " (Let\n" + " (" + r(keyColsVar) + " " + r(okVar) + "):(CanConvertUnionToDistinctUnionAll\n" + " " + r(leftCols) + "\n" + " " + r(rightCols) + "\n" + " )\n" + " " + r(okVar) + "\n" + " )"; + String pattern = N("Union", b(leftVar), b(rightVar), setPriv); + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchUnionPullUpConstants(Env env, RelRN.Union union, RelRN.Project leftProject, RelRN.Project rightProject) { + Env liEnv = onMatch(env, leftProject.source()); + String leftInputVar = liEnv.generateVar("leftInput"); + Env liBound = liEnv.addBinding("leftInput", leftInputVar); + String leftProjVar = liBound.generateVar("leftProjections"); + Env lpEnv = liBound.addBinding("leftProjections", leftProjVar); + String leftPassVar = lpEnv.generateVar("leftPassthrough"); + Env lpassEnv = lpEnv.addBinding("leftPassthrough", leftPassVar); + String leftVar = lpassEnv.generateVar("left"); + Env lBound = lpassEnv.addBinding("left", leftVar); + Env riEnv = onMatch(lBound, rightProject.source()); + String rightInputVar = riEnv.generateVar("rightInput"); + Env riBound = riEnv.addBinding("rightInput", rightInputVar); + String rightProjVar = riBound.generateVar("rightProjections"); + Env rpEnv = riBound.addBinding("rightProjections", rightProjVar); + String rightPassVar = rpEnv.generateVar("rightPassthrough"); + Env rpassEnv = rpEnv.addBinding("rightPassthrough", rightPassVar); + String rightVar = rpassEnv.generateVar("right"); + Env rBound = rpassEnv.addBinding("right", rightVar); + String privateVar = rBound.generateVar("private"); + Env privEnv = rBound.addBinding("private", privateVar); + String leftCols = privEnv.generateVar("leftCols"); + Env lcEnv = privEnv.addBinding("leftCols", leftCols); + String rightCols = lcEnv.generateVar("rightCols"); + Env rcEnv = lcEnv.addBinding("rightCols", rightCols); + String outCols = rcEnv.generateVar("outCols"); + Env resultEnv = 
rcEnv.addBinding("outCols", outCols); + String uType = union.all() ? "UnionAll" : "Union"; + String setPriv = r(privateVar) + ":(SetPrivate " + b(leftCols) + " " + b(rightCols) + " " + b(outCols) + ") &\n" + " (HasMatchingConstantsFromUnion\n" + " " + r(leftProjVar) + "\n" + " " + r(rightProjVar) + "\n" + " " + r(leftCols) + "\n" + " " + r(rightCols) + "\n" + " " + r(outCols) + "\n" + " )"; + String pattern = "(" + uType + "\n" + " " + r(leftVar) + ":(Project\n" + " " + b(leftInputVar) + "\n" + " " + b(leftProjVar) + "\n" + " " + b(leftPassVar) + "\n" + " )\n" + " " + r(rightVar) + ":(Project\n" + " " + b(rightInputVar) + "\n" + " " + b(rightProjVar) + "\n" + " " + b(rightPassVar) + "\n" + " )\n" + " " + setPriv + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + private Env matchUnionGeneral(Env env, RelRN.Union union) { + Env currentEnv = env; + Seq sourcePatterns = Seq.empty(); + for (RelRN source : union.sources()) { + Env sourceEnv = onMatch(currentEnv, source); + if (source instanceof RelRN.Union) { + String subPrivate = sourceEnv.bindings().get("union_private"); + if (subPrivate != null) { + sourceEnv = sourceEnv.addBinding("inner_union_private", subPrivate); + } + } + sourcePatterns = sourcePatterns.appended(sourceEnv.current()); + currentEnv = sourceEnv; + } + String privateVar = currentEnv.generateVar("private"); + Env resultEnv = currentEnv.addBinding("union_private", privateVar); + String uType = union.all() ? "UnionAll" : "Union"; + String pattern = sourcePatterns.size() == 2 + ? 
N(uType, sourcePatterns.get(0), sourcePatterns.get(1), b(privateVar)) + : buildNestedUnion(uType, sourcePatterns, privateVar + ":*"); + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Right-nests the given source patterns into binary uType nodes: the first source is the left child and the remaining sources are nested recursively as the right child. Every level receives "$" + privatePattern as its last operand. The recursion only stops at exactly two sources. NOTE(review): a sequence with fewer than two entries never reaches the base case and would fail on get(0) of an empty sequence; confirm callers always pass at least two. */ private String buildNestedUnion(String uType, Seq sources, String privatePattern) { + if (sources.size() == 2) { + return N(uType, sources.get(0), sources.get(1), "$" + privatePattern); + } + return N(uType, sources.get(0), buildNestedUnion(uType, sources.drop(1), privatePattern), "$" + privatePattern ); + } + /** Builds the match pattern for an Intersect. Cases are tried in order: (1) two sources whose first source is itself a two-source Intersect are delegated to matchNestedIntersect; (2) two sources whose second source is RelRN.Empty produce a pattern with a HasZeroRows condition on the right variable and set the bindings "isPruneEmptyIntersect" and "pruneEmptyLeft" (the sources are not matched recursively in this case); (3) anything else is matched source by source, binding "intersect_private" and emitting "IntersectAll" or "Intersect" depending on intersect.all(). */ @Override + public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { + if (intersect.sources().size() == 2 && intersect.sources().get(0) instanceof RelRN.Intersect inner && inner.sources().size() == 2) { + return matchNestedIntersect(env, intersect); + } + if (intersect.sources().size() == 2 && intersect.sources().get(1) instanceof RelRN.Empty) { + String leftVar = env.generateVar("left"); + String rightVar = env.generateVar("right"); + String iType = intersect.all() ? "IntersectAll" : "Intersect"; + String pattern = "(" + iType + "\n" + " " + b(leftVar) + "\n" + " " + b(rightVar) + " & (HasZeroRows " + r(rightVar) + ")\n" + ")"; + return env.addBinding("isPruneEmptyIntersect", "true") .addBinding("pruneEmptyLeft", leftVar) .setPattern(pattern).focus(pattern); + } + Env currentEnv = env; + Seq sourcePatterns = Seq.empty(); + for (RelRN source : intersect.sources()) { + Env sourceEnv = onMatch(currentEnv, source); + /* Same bookkeeping as for unions: keep a nested intersect's private variable reachable under "inner_intersect_private". */ if (source instanceof RelRN.Intersect) { + String subPrivate = sourceEnv.bindings().get("intersect_private"); + if (subPrivate != null) { + sourceEnv = sourceEnv.addBinding("inner_intersect_private", subPrivate); + } + } + sourcePatterns = sourcePatterns.appended(sourceEnv.current()); + currentEnv = sourceEnv; + } + String privateVar = currentEnv.generateVar("private"); + Env resultEnv = currentEnv.addBinding("intersect_private", privateVar); + String iType = intersect.all() ?
"IntersectAll" : "Intersect"; + String pattern = sourcePatterns.size() == 2 + ? N(iType, sourcePatterns.get(0), sourcePatterns.get(1), b(privateVar)) + : buildNestedIntersect(iType, sourcePatterns, privateVar + ":*"); + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Builds the match pattern for an intersect whose left input is itself an intersect. The operands are not matched recursively; instead fresh variables are generated and bound, in this order, under the keys "leftLeft", "leftRight", "innerPrivate", "innerLeftCols", "innerRightCols", "left", "right", "outerPrivate", "outerRightCols" and "outerOutCols", and the binding "isIntersectMerge" is set to "true". The inner and outer private operands are destructured with SetPrivate patterns. Both levels use the same node name, chosen from intersect.all(). */ private Env matchNestedIntersect(Env env, RelRN.Intersect intersect) { + String leftLeftVar = env.generateVar("leftLeft"); + Env llEnv = env.addBinding("leftLeft", leftLeftVar); + String leftRightVar = llEnv.generateVar("leftRight"); + Env lrEnv = llEnv.addBinding("leftRight", leftRightVar); + String innerPrivate = lrEnv.generateVar("innerPrivate"); + Env ipEnv = lrEnv.addBinding("innerPrivate", innerPrivate); + String innerLeftCols = ipEnv.generateVar("innerLeftCols"); + Env ilcEnv = ipEnv.addBinding("innerLeftCols", innerLeftCols); + String innerRightCols = ilcEnv.generateVar("innerRightCols"); + Env ircEnv = ilcEnv.addBinding("innerRightCols", innerRightCols); + String leftVar = ircEnv.generateVar("left"); + Env lEnv = ircEnv.addBinding("left", leftVar); + String rightVar = lEnv.generateVar("right"); + Env rEnv = lEnv.addBinding("right", rightVar); + String outerPrivate = rEnv.generateVar("outerPrivate"); + Env opEnv = rEnv.addBinding("outerPrivate", outerPrivate); + String outerRightCols = opEnv.generateVar("outerRightCols"); + Env orcEnv = opEnv.addBinding("outerRightCols", outerRightCols); + String outerOutCols = orcEnv.generateVar("outerOutCols"); + Env resultEnv = orcEnv .addBinding("outerOutCols", outerOutCols) .addBinding("isIntersectMerge", "true"); + String iType = intersect.all() ?
"IntersectAll" : "Intersect"; + String innerSetPriv = r(innerPrivate) + ":(SetPrivate " + b(innerLeftCols) + " " + b(innerRightCols) + " *)"; + String outerSetPriv = r(outerPrivate) + ":(SetPrivate * " + b(outerRightCols) + " " + b(outerOutCols) + ")"; + String pattern = "(" + iType + "\n" + " " + r(leftVar) + ":(" + iType + "\n" + " " + b(leftLeftVar) + "\n" + " " + b(leftRightVar) + "\n" + " " + innerSetPriv + "\n" + " )\n" + " " + b(rightVar) + "\n" + " " + outerSetPriv + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Intersect counterpart of buildNestedUnion: right-nests the source patterns into binary iType nodes, passing "$" + privatePattern as the last operand at every level. The recursion only stops at exactly two sources. NOTE(review): fewer than two sources would fail on an empty sequence; confirm callers. */ private String buildNestedIntersect(String iType, Seq sources, String privatePattern) { + if (sources.size() == 2) { + return N(iType, sources.get(0), sources.get(1), "$" + privatePattern); + } + return N(iType, sources.get(0), buildNestedIntersect(iType, sources.drop(1), privatePattern), "$" + privatePattern ); + } + /** Builds the match pattern for a Minus, emitted as an "Except" node. Cases are tried in order: (1) when the rule name is "MinusMerge" and the first of two sources is itself a Minus, delegate to matchMinusMerge; (2) when the first of two sources is RelRN.Empty, emit a pattern with a HasZeroRows condition on the left variable and set "isPruneEmptyMinus", "pruneEmptyLeft" and "right"; (3) otherwise match every source in order and bind a fresh variable under "minus_private". NOTE(review): in case (2) the right variable is the fixed string "right" rather than a name from env.generateVar; confirm this is intended. */ @Override + public Env onMatchMinus(Env env, RelRN.Minus minus) { + if (env.rulename.equals("MinusMerge") && minus.sources().size() == 2 && minus.sources().get(0) instanceof RelRN.Minus) { + return matchMinusMerge(env, minus); + } + if (minus.sources().size() == 2 && minus.sources().get(0) instanceof RelRN.Empty) { + String leftVar = env.generateVar("left"); + String rightVar = "right"; + String pattern = "(Except\n" + " " + b(leftVar) + " & (HasZeroRows " + r(leftVar) + ")\n" + " " + r(rightVar) + ":*\n" + ")"; + return env.addBinding("isPruneEmptyMinus", "true") .addBinding("pruneEmptyLeft", leftVar) .addBinding("right", rightVar) .setPattern(pattern).focus(pattern); + } + Env currentEnv = env; + Seq sourcePatterns = Seq.empty(); + for (RelRN source : minus.sources()) { + Env sourceEnv = onMatch(currentEnv, source); + sourcePatterns = sourcePatterns.appended(sourceEnv.current()); + currentEnv = sourceEnv; + } + String privateVar = currentEnv.generateVar("private"); + Env resultEnv = currentEnv.addBinding("minus_private", privateVar); + String pattern; + if (minus.sources().size() == 2) { + pattern =
N("Except", sourcePatterns.get(0), sourcePatterns.get(1), b(privateVar)); + } else { + /* Any other arity is emitted as one flat Except node listing all sources, followed by the private operand. */ pattern = "(Except\n " + sourcePatterns.joinToString("\n ") + "\n " + b(privateVar) + "\n)"; + } + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Builds the match pattern for an Except whose left input is another Except. The operands are not matched recursively; fresh variables are bound under "left", "leftLeft", "leftRight", "innerPrivate", "right" and "outerPrivate", and the binding "isMinusMerge" is set to "true". The minus argument is not read by this method. */ private Env matchMinusMerge(Env env, RelRN.Minus minus) { + String leftVar = env.generateVar("left"); + Env lEnv = env.addBinding("left", leftVar); + String leftLeftVar = lEnv.generateVar("leftLeft"); + Env llEnv = lEnv.addBinding("leftLeft", leftLeftVar); + String leftRightVar = llEnv.generateVar("leftRight"); + Env lrEnv = llEnv.addBinding("leftRight", leftRightVar); + String innerPriv = lrEnv.generateVar("innerPrivate"); + Env ipEnv = lrEnv.addBinding("innerPrivate", innerPriv); + String rightVar = ipEnv.generateVar("right"); + Env rEnv = ipEnv.addBinding("right", rightVar); + String outerPriv = rEnv.generateVar("outerPrivate"); + Env resultEnv = rEnv .addBinding("outerPrivate", outerPriv) .addBinding("isMinusMerge", "true"); + String pattern = "(Except\n" + " " + r(leftVar) + ":(Except\n" + " " + b(leftLeftVar) + "\n" + " " + b(leftRightVar) + "\n" + " " + b(innerPriv) + "\n" + " )\n" + " " + b(rightVar) + "\n" + " " + b(outerPriv) + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Builds the match pattern for an Aggregate. Special shapes are checked first, in order: a LEFT join whose left input is also a LEFT join (matchAggregateDoubleJoin), a LEFT join source (matchAggregateLeftJoin), and a Project source (matchAggregateProjectSource). Otherwise the source, the aggregate calls and the group set are matched in that order and a fresh variable is bound under "aggregate_private". If the aggregate contains RexRN.Proj expressions (see hasProjectionExpressionsInAggregate) a pattern guarded by CanExtractProjectFromAggregate is returned with "isAggregateExtractProject" set; otherwise the generic node is emitted, with an extra filtersBoundBy condition when the source is a Filter and "filterCondVar" is bound. */ @Override + public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + if (aggregate.source() instanceof RelRN.Join topJoin && topJoin.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT && topJoin.left() instanceof RelRN.Join bottomJoin && bottomJoin.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT) { + return matchAggregateDoubleJoin(env); + } + if (aggregate.source() instanceof RelRN.Join j && j.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT) { + return matchAggregateLeftJoin(env, aggregate); + } + if (aggregate.source() instanceof RelRN.Project) { + return matchAggregateProjectSource(env, aggregate); + } + Env sourceEnv = onMatch(env, aggregate.source()); +
String sourcePat = sourceEnv.current(); + Env aggsEnv = onMatchAggCalls(sourceEnv, aggregate.aggCalls()); + String aggsPat = aggsEnv.current(); + /* The group set is matched for its bindings; its pattern text is not placed into the emitted node below. */ Env groupEnv = onMatchGroupSet(aggsEnv, aggregate.groupSet()); + String privateVar = groupEnv.generateVar("private"); + Env privEnv = groupEnv.addBinding("aggregate_private", privateVar); + String aggType = determineAggregateType(aggregate); + if (hasProjectionExpressionsInAggregate(aggregate)) { + /* This branch ignores sourcePat and aggsPat and binds its own "input", "aggregations" and "groupingPrivate" variables. */ String inputVar = privEnv.generateVar("input"); + Env inputEnv = privEnv.addBinding("input", inputVar); + String aggregationsVar = inputEnv.generateVar("aggregations"); + Env aggsBindEnv = inputEnv.addBinding("aggregations", aggregationsVar); + String groupingPriv = aggsBindEnv.generateVar("groupingPrivate"); + Env gpEnv = aggsBindEnv.addBinding("groupingPrivate", groupingPriv); + String condLine = b(aggregationsVar) + " & (CanExtractProjectFromAggregate " + r(aggregationsVar) + ")"; + String pattern = N(aggType, b(inputVar), condLine, b(groupingPriv)); + return gpEnv.addBinding("isAggregateExtractProject", "true") .setPattern(pattern).focus(pattern); + } + String filterBoundBy = ""; + if (aggregate.source() instanceof RelRN.Filter) { + /* "filterCondVar" is read from the Env produced by matching the source; when it is absent no extra condition is appended. */ String condVarName = get(sourceEnv, "filterCondVar", null); + if (condVarName != null) { + filterBoundBy = " & " + filtersBoundBy(condVarName, "(GroupingCols " + r(privateVar) + ")"); + } + } + String pattern = N(aggType, sourcePat, aggsPat, b(privateVar) + filterBoundBy); + return privEnv.setPattern(pattern).focus(pattern); + } + /** Builds the match pattern for a DistinctOn over a LeftJoin whose left input is another LeftJoin. Nothing is matched recursively: fresh variables are bound under "left", "middle", "right", "rightFilters", "aggregations", "groupingPrivate", "groupingCols" and "ordering". The emitted pattern requires an empty aggregation list and adds two conditions to the grouping private: a ColsAreEmpty check on the intersection of the middle input's output columns with the union of the right filter's outer columns and the grouping columns, and an OrderingCanProjectCols check against the left and right output columns. */ private Env matchAggregateDoubleJoin(Env env) { + String leftVar = env.generateVar("left"); + Env lEnv = env.addBinding("left", leftVar); + String middleVar = lEnv.generateVar("middle"); + Env mEnv = lEnv.addBinding("middle", middleVar); + String rightVar = mEnv.generateVar("right"); + Env rEnv = mEnv.addBinding("right", rightVar); + String rightFiltersVar = rEnv.generateVar("rightFilters"); + Env rfEnv = rEnv.addBinding("rightFilters", rightFiltersVar); + String
aggregationsVar = rfEnv.generateVar("aggregations"); + Env aggsEnv = rfEnv.addBinding("aggregations", aggregationsVar); + String groupingPrivVar = aggsEnv.generateVar("groupingPrivate"); + Env gpEnv = aggsEnv.addBinding("groupingPrivate", groupingPrivVar); + String groupingColsVar = gpEnv.generateVar("groupingCols"); + Env gcEnv = gpEnv.addBinding("groupingCols", groupingColsVar); + String orderingVar = gcEnv.generateVar("ordering"); + Env resultEnv = gcEnv.addBinding("ordering", orderingVar); + String pattern = "(DistinctOn\n" + " (LeftJoin\n" + " (LeftJoin\n" + " " + b(leftVar) + "\n" + " " + b(middleVar) + "\n" + " *\n" + " )\n" + " " + b(rightVar) + "\n" + " " + b(rightFiltersVar) + "\n" + " )\n" + " " + r(aggregationsVar) + ":[]\n" + " " + r(groupingPrivVar) + ":(GroupingPrivate " + b(groupingColsVar) + " " + b(orderingVar) + ") &\n" + " (ColsAreEmpty\n" + " (IntersectionCols\n" + " (OutputCols " + r(middleVar) + ")\n" + " (UnionCols\n" + " (FilterOuterCols " + r(rightFiltersVar) + ")\n" + " " + r(groupingColsVar) + "\n" + " )\n" + " )\n" + " ) &\n" + " (OrderingCanProjectCols\n" + " " + r(orderingVar) + "\n" + " (UnionCols (OutputCols " + r(leftVar) + ") (OutputCols " + r(rightVar) + "))\n" + " )\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Builds the match pattern for a DistinctOn over a single LeftJoin. Only the join's left input is captured (the right input and the join filters are wildcards). Fresh variables are bound under "left", "aggregations", "groupingPrivate", "groupingCols", "ordering" and "leftCols". The pattern requires an empty aggregation list, requires the grouping columns to be a subset of the left input's output columns (ColsAreSubset), and requires OrderingCanProjectCols for those same columns. The aggregate argument is not read by this method. */ private Env matchAggregateLeftJoin(Env env, RelRN.Aggregate aggregate) { + String leftVar = env.generateVar("left"); + Env lEnv = env.addBinding("left", leftVar); + String aggsVar = lEnv.generateVar("aggregations"); + Env aggsEnv = lEnv.addBinding("aggregations", aggsVar); + String groupingPrivVar = aggsEnv.generateVar("groupingPrivate"); + Env gpEnv = aggsEnv.addBinding("groupingPrivate", groupingPrivVar); + String groupingColsVar = gpEnv.generateVar("groupingCols"); + Env gcEnv = gpEnv.addBinding("groupingCols", groupingColsVar); + String orderingVar = gcEnv.generateVar("ordering"); + Env ordEnv = gcEnv.addBinding("ordering", orderingVar); + String leftColsVar = ordEnv.generateVar("leftCols"); +
Env resultEnv = ordEnv.addBinding("leftCols", leftColsVar); + String pattern = "(DistinctOn\n" + " (LeftJoin\n" + " " + b(leftVar) + "\n" + " *\n" + " *\n" + " )\n" + " " + r(aggsVar) + ":[]\n" + " " + r(groupingPrivVar) + ":(GroupingPrivate " + b(groupingColsVar) + " " + b(orderingVar) + ") &\n" + " (ColsAreSubset\n" + " " + r(groupingColsVar) + "\n" + " " + r(leftColsVar) + ":(OutputCols " + r(leftVar) + ")\n" + " ) &\n" + " (OrderingCanProjectCols\n" + " " + r(orderingVar) + "\n" + " " + r(leftColsVar) + "\n" + " )\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Builds the match pattern for an aggregate whose source is a Project. The shape depends on the rule name: for "AggregateProjectConstantToDummyJoin" the whole Project is captured as "input" and the grouping private is guarded by HasConstantGroupingCols; for "AggregateProjectMerge" the Project is captured as "input", its own input as "source", and the grouping private is guarded by CanMergeProjectIntoAggregate; for every other rule the Project is destructured into "input", "projections" and "passthrough" and the grouping private is guarded by CanRemapGroupingColsThroughProject. All branches also bind "aggregations" and "groupingPrivate". Nothing is matched recursively. */ private Env matchAggregateProjectSource(Env env, RelRN.Aggregate aggregate) { + String aggType = determineAggregateType(aggregate); + if (env.rulename.equals("AggregateProjectConstantToDummyJoin")) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String aggsVar = inputEnv.generateVar("aggregations"); + Env aggsEnv = inputEnv.addBinding("aggregations", aggsVar); + String gpVar = aggsEnv.generateVar("groupingPrivate"); + Env resultEnv = aggsEnv.addBinding("groupingPrivate", gpVar); + String gpLine = b(gpVar) + " & (HasConstantGroupingCols " + r(inputVar) + " " + r(gpVar) + ")"; + String pattern = N(aggType, r(inputVar) + ":(Project * * *)", b(aggsVar), gpLine); + return resultEnv.setPattern(pattern).focus(pattern); + } + if (env.rulename.equals("AggregateProjectMerge")) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + /* NOTE(review): the variable bound under "source" is generated with the "input" prefix; confirm that is intended. */ String sourceVar = inputEnv.generateVar("input"); + Env sBound = inputEnv.addBinding("source", sourceVar); + String aggsVar = sBound.generateVar("aggregations"); + Env aggsEnv = sBound.addBinding("aggregations", aggsVar); + String gpVar = aggsEnv.generateVar("groupingPrivate"); + Env resultEnv = aggsEnv.addBinding("groupingPrivate", gpVar); +
String gpLine = b(gpVar) + " & (CanMergeProjectIntoAggregate " + r(inputVar) + " " + r(gpVar) + ")"; + /* The pattern is built once. An earlier draft value that was assembled from a separate innerProject string and then immediately overwritten has been removed, since it never reached the returned Env. */ String pattern = "(" + aggType + "\n" + " " + r(inputVar) + ":(Project\n" + " " + b(sourceVar) + "\n" + " *\n" + " *\n" + " )\n" + " " + b(aggsVar) + "\n" + " " + gpLine + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String projectionsVar = inputEnv.generateVar("projections"); + Env projEnv = inputEnv.addBinding("projections", projectionsVar); + String passthroughVar = projEnv.generateVar("passthrough"); + Env passEnv = projEnv.addBinding("passthrough", passthroughVar); + String aggsVar = passEnv.generateVar("aggregations"); + Env aggsEnv = passEnv.addBinding("aggregations", aggsVar); + String gpVar = aggsEnv.generateVar("groupingPrivate"); + Env resultEnv = aggsEnv.addBinding("groupingPrivate", gpVar); + String gpLine = b(gpVar) + " & (CanRemapGroupingColsThroughProject " + r(gpVar) + " " + r(projectionsVar) + " " + r(passthroughVar) + ")"; + String innerProject = "(Project\n" + " " + b(inputVar) + "\n" + " " + b(projectionsVar) + "\n" + " " + b(passthroughVar) + "\n" + " )"; + String pattern = "(" + aggType + "\n" + " " + innerProject + "\n" + " " + b(aggsVar) + "\n" + " " + gpLine + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + /** Builds the pattern for an aggregate's call list. A call whose single operand is a RexRN.Proj with an operator name that is already bound reuses that binding and creates no new variable; every other call binds its name to a fresh "agg" variable. The resulting pattern is: the reused projection variable when there is exactly one call and it was reused; a fresh variable bound under "aggregations" when there is exactly one call otherwise (the per-call "agg" binding is still recorded but does not appear in the pattern); and a bracketed, space-separated list in all other cases, including an empty call list. */ private Env onMatchAggCalls(Env env, Seq aggCalls) { + Env currentEnv = env; + Seq aggPatterns = Seq.empty(); + boolean hasProjOp = false; + for (RelRN.AggCall aggCall : aggCalls) { + if (aggCall.operands().size() == 1 && aggCall.operands().get(0) instanceof RexRN.Proj proj) { + String projVar = currentEnv.bindings().getOrDefault(proj.operator().getName(), null); + if (projVar != null) { + aggPatterns = aggPatterns.appended(b(projVar)); + hasProjOp = true; + continue; + } + }
+ String aggVar = currentEnv.generateVar("agg"); + currentEnv = currentEnv.addBinding(aggCall.name(), aggVar); + aggPatterns = aggPatterns.appended(b(aggVar)); + } + String pattern; + if (aggCalls.size() == 1 && hasProjOp) { + pattern = aggPatterns.get(0); + } else if (aggCalls.size() == 1) { + /* A single ordinary call is matched through one fresh variable bound under "aggregations" instead of a one-element list. */ String aggVar = currentEnv.generateVar("aggregations"); + currentEnv = currentEnv.addBinding("aggregations", aggVar); + pattern = b(aggVar); + } else { + pattern = "[" + aggPatterns.joinToString(" ") + "]"; + } + return currentEnv.setPattern(pattern).focus(pattern); + } + /** Matches every grouping expression in order via onMatch, threading the Env, and returns an Env whose pattern is the bracketed, space-separated list of the individual patterns. */ private Env onMatchGroupSet(Env env, Seq groupSet) { + Env currentEnv = env; + Seq groupPatterns = Seq.empty(); + for (RexRN groupCol : groupSet) { + Env groupEnv = onMatch(currentEnv, groupCol); + groupPatterns = groupPatterns.appended(groupEnv.current()); + currentEnv = groupEnv; + } + String pattern = "[" + groupPatterns.joinToString(" ") + "]"; + return currentEnv.setPattern(pattern).focus(pattern); + } + /** Returns the operator name used for aggregates in emitted patterns. The argument is currently ignored and the result is always "GroupBy". */ private static String determineAggregateType(RelRN.Aggregate aggregate) { + return "GroupBy"; + } + /** Returns true when any grouping expression, or any operand of any aggregate call, is a RexRN.Proj; false otherwise. */ private static boolean hasProjectionExpressionsInAggregate(RelRN.Aggregate aggregate) { + for (RexRN g : aggregate.groupSet()) { + if (g instanceof RexRN.Proj) return true; + } + for (RelRN.AggCall c : aggregate.aggCalls()) { + for (RexRN op : c.operands()) { + if (op instanceof RexRN.Proj) return true; + } + } + return false; + } + /** Matches an empty relation: binds a fresh variable under "empty" and focuses on that variable constrained to a (Values) node. Only focus is called here, not setPattern. */ @Override + public Env onMatchEmpty(Env env, RelRN.Empty empty) { + String var = env.generateVar("empty"); + return env.addBinding("empty", var).focus(r(var) + ":(Values)"); + } + /** Matches a field reference: binds a fresh variable under "field_" followed by the field's ordinal and focuses on b(var). */ @Override + public Env onMatchField(Env env, RexRN.Field field) { + String var = env.generateVar("field"); + return env.addBinding("field_" + field.ordinal(), var).focus(b(var)); + } + /** Matches a predicate: binds a fresh "cond" variable under the predicate operator's name and focuses on b(var). */ @Override + public Env onMatchPred(Env env, RexRN.Pred pred) { + String var = env.generateVar("cond"); + return env.addBinding(pred.operator().getName(), var).focus(b(var)); + } + /** Matches a projection expression: binds a fresh "proj" variable under the projection operator's name and focuses on b(var). */ @Override + public Env
onMatchProj(Env env, RexRN.Proj proj) { + String var = env.generateVar("proj"); + return env.addBinding(proj.operator().getName(), var).focus(b(var)); + } + /** Matches a GroupBy expression. When it has a single source that is a RexRN.Proj whose operator name is already bound, the existing variable is reused and no binding is added; otherwise a fresh "groupBy" variable is bound under the GroupBy operator's name. */ public Env onMatchGroupBy(Env env, RexRN.GroupBy groupBy) { + if (groupBy.sources().size() == 1 && groupBy.sources().get(0) instanceof RexRN.Proj proj) { + String projVar = env.bindings().getOrDefault(proj.operator().getName(), null); + if (projVar != null) return env.focus(b(projVar)); + } + String var = env.generateVar("groupBy"); + return env.addBinding(groupBy.operator().getName(), var).focus(b(var)); + } + /** Matches a conjunction: each operand is matched in order via onMatch, threading the Env, and the operand patterns are combined with buildNestedAnd. */ @Override + public Env onMatchAnd(Env env, RexRN.And and) { + Env currentEnv = env; + Seq operandPats = Seq.empty(); + for (RexRN op : and.sources()) { + Env opEnv = onMatch(currentEnv, op); + operandPats = operandPats.appended(opEnv.current()); + currentEnv = opEnv; + } + String pattern = buildNestedAnd(operandPats); + return currentEnv.setPattern(pattern).focus(pattern); + } + /** Combines operand patterns into right-nested binary And nodes. An empty sequence yields "(And)", a single operand is returned unchanged, and longer sequences nest the tail as the second argument. */ private static String buildNestedAnd(Seq operands) { + if (operands.isEmpty()) return "(And)"; + if (operands.size() == 1) return operands.get(0); + return "(And " + operands.get(0) + " " + buildNestedAnd(operands.drop(1)) + ")"; + } + /** Matches a TRUE literal: binds a fresh "true" variable and constrains it to True. NOTE(review): the binding key embeds System.identityHashCode(literal), so it differs between literal instances and between runs; confirm nothing needs to look this key up by a stable name. */ @Override + public Env onMatchTrue(Env env, RexRN literal) { + String var = env.generateVar("true"); + return env.addBinding("true_" + System.identityHashCode(literal), var) .focus(r(var) + ":True") .setPattern(r(var) + ":True"); + } + /** Matches a FALSE literal with the fixed pattern "(False)"; no variable is bound. */ @Override + public Env onMatchFalse(Env env, RexRN literal) { + return env.focus("(False)").setPattern("(False)"); + } + /** Matches rule-specific custom relational nodes declared under org.qed.RRuleInstances. Each known node type is handled by an instanceof check, either by emitting a dedicated pattern or by rewriting the node into a standard RelRN node and delegating to the corresponding onMatch method. Unknown node types fall through to unimplementedOnMatch. */ @Override + public Env onMatchCustom(Env env, RelRN custom) { + /* AggregateProjectConstantToDummyJoin: an aggregate over a projection with constant literals is matched as a GroupBy over an opaque Project guarded by HasConstantGroupingCols. */ if (custom instanceof org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin .AggregateGroupingByConstants aggGrouping) { + if (aggGrouping.input() instanceof org.qed.RRuleInstances .AggregateProjectConstantToDummyJoin.ProjectWithConstantLiterals) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String aggsVar =
inputEnv.generateVar("aggregations"); + Env aggsEnv = inputEnv.addBinding("aggregations", aggsVar); + String gpVar = aggsEnv.generateVar("groupingPrivate"); + Env resultEnv = aggsEnv.addBinding("groupingPrivate", gpVar); + String gpLine = b(gpVar) + " & (HasConstantGroupingCols " + r(inputVar) + " " + r(gpVar) + ")"; + String pattern = N("GroupBy", r(inputVar) + ":(Project * * *)", b(aggsVar), gpLine); + return resultEnv.setPattern(pattern).focus(pattern); + } + return onMatch(env, aggGrouping.input()); + } + /* A projection with constant literals is matched through its input; a SourceTable input is matched as a scan named "Source". */ if (custom instanceof org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin .ProjectWithConstantLiterals projectWithConstants) { + if (projectWithConstants.input() instanceof org.qed.RRuleInstances .AggregateProjectConstantToDummyJoin.SourceTable) { + return onMatchScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + return onMatch(env, projectWithConstants.input()); + } + if (custom instanceof org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin .SourceTable) { + return onMatchScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + /* ProjectAggregateMerge: a Project over a GroupBy, guarded by CanPruneAggCols on the columns the projection needs. */ if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge .ProjectUsingSubsetOfAggregates pusa) { + if (pusa.input() instanceof org.qed.RRuleInstances.ProjectAggregateMerge .AggregateWithMultipleCalls amc) { + Env srcEnv = onMatch(env, amc.input()); + /* The variable name is recovered by stripping "$" and ":*" from the matched source pattern. NOTE(review): this presumably relies on the source pattern being a single variable binder; confirm for non-scan inputs. */ String aggInputVar = srcEnv.current().replace("$", "").replace(":*", "").trim(); + Env aiBound = srcEnv.addBinding("aggInput", aggInputVar); + String aggsVar = aiBound.generateVar("aggregations"); + Env aggsEnv = aiBound.addBinding("aggregations", aggsVar); + String gpVar = aggsEnv.generateVar("groupingPrivate"); + Env gpEnv = aggsEnv.addBinding("groupingPrivate", gpVar); + String inputVar = gpEnv.generateVar("input"); + Env inputEnv = gpEnv.addBinding("input", inputVar); + String projVar = inputEnv.generateVar("projections"); + Env projEnv = inputEnv.addBinding("projections", projVar);
+ String passVar = projEnv.generateVar("passthrough"); + Env passEnv = projEnv.addBinding("passthrough", passVar); + String neededVar = passEnv.generateVar("needed"); + Env resultEnv = passEnv.addBinding("needed", neededVar); + String passCond = b(passVar) + " &\n" + " (CanPruneAggCols\n" + " " + r(aggsVar) + "\n" + " " + r(neededVar) + ":(UnionCols\n" + " (ProjectionOuterCols " + r(projVar) + ")\n" + " " + r(passVar) + "\n" + " )\n" + " )"; + String pattern = "(Project\n" + " " + r(inputVar) + ":(GroupBy\n" + " " + b(aggInputVar) + "\n" + " " + b(aggsVar) + "\n" + " " + b(gpVar) + "\n" + " )\n" + " " + b(projVar) + "\n" + " " + passCond + "\n" + ")"; + return resultEnv.setPattern(pattern).focus(pattern); + } + return onMatch(env, pusa.input()); + } + if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge .AggregateWithMultipleCalls amc) { + return onMatch(env, amc.input()); + } + if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge.SourceTable) { + return onMatchScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + /* UnionPullUpConstants: the custom union is matched as a UNION ALL of its two sides, and each side as a Project of field 0 of its input. */ if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .UnionWithConstantColumns uwcc) { + return onMatchUnion(env, new RelRN.Union(true, Seq.of(uwcc.left(), uwcc.right()))); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .LeftProjectionWithConstants lp) { + return onMatchProject(env, new RelRN.Project(lp.input().field(0), lp.input())); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .RightProjectionWithConstants rp) { + return onMatchProject(env, new RelRN.Project(rp.input().field(0), rp.input())); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants.SourceTable) { + return onMatchScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + /* UnionToDistinct: DistinctUnion maps to a Union with all=false, UnionAll to a Union with all=true. */ if (custom instanceof org.qed.RRuleInstances.UnionToDistinct.DistinctUnion du) { + return onMatchUnion(env, new RelRN.Union(false,
Seq.of(du.left(), du.right()))); + } + if (custom instanceof org.qed.RRuleInstances.UnionToDistinct.UnionAll ua) { + return onMatchUnion(env, new RelRN.Union(true, Seq.of(ua.left(), ua.right()))); + } + return unimplementedOnMatch(env, custom); + } + /** Matches custom scalar nodes. Only RexRN.GroupBy is supported, by delegating to onMatchGroupBy; anything else goes to unimplementedOnMatch. */ @Override + public Env onMatchCustom(Env env, RexRN custom) { + if (custom instanceof RexRN.GroupBy groupBy) return onMatchGroupBy(env, groupBy); + return unimplementedOnMatch(env, custom); + } + /** Emits the replacement for a scan: a reference r(var) to the variable bound under the scan's name, using "input" as the variable name when no such binding exists. */ @Override + public Env transformScan(Env env, RelRN.Scan scan) { + String var = get(env, scan.name(), "input"); + String pattern = r(var); + return env.setPattern(pattern).focus(pattern); + } + /** Emits the replacement for a Filter. Cases are tried in order: (1) the "isFilterSetOpTranspose" flag with a Union source and the bindings left, right, colmap, filter and item present, handled by transformFilterSetOpTranspose; (2) rule "JoinExtractFilter" with a Join source and the bindings left, right, on and private present, emitted as ConstructJoinExtractFilterResult; (3) the "isPruneEmptyFilter" flag, emitted as a reference to "pruneEmptyInput"; (4) a TRUE condition, or (5) an Empty source, both replaced by the transformed source alone; (6) a FALSE condition, emitted as ConstructEmptyValues over the source's output columns; (7) otherwise a Select over the transformed source and condition. */ @Override + public Env transformFilter(Env env, RelRN.Filter filter) { + if (flag(env, "isFilterSetOpTranspose") && filter.source() instanceof RelRN.Union && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("colmap") && env.bindings().containsKey("filter") && env.bindings().containsKey("item")) { + return transformFilterSetOpTranspose(env); + } + if (env.rulename.equals("JoinExtractFilter") && filter.source() instanceof RelRN.Join && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("on") && env.bindings().containsKey("private")) { + String pattern = N("ConstructJoinExtractFilterResult", r(get(env, "left")), r(get(env, "right")), r(get(env, "on")), r(get(env, "private")) ); + return env.setPattern(pattern).focus(pattern); + } + if (flag(env, "isPruneEmptyFilter")) { + String pattern = r(get(env, "pruneEmptyInput")); + return env.setPattern(pattern).focus(pattern); + } + if (filter.cond() instanceof RexRN.True) { + return transform(env, filter.source()); + } + if (filter.source() instanceof RelRN.Empty) { + return transform(env, filter.source()); + } + if (filter.cond() instanceof RexRN.False) { + Env srcEnv = transform(env, filter.source()); + String pattern = "(ConstructEmptyValues (OutputCols " + srcEnv.current() + "))"; +
return srcEnv.setPattern(pattern).focus(pattern); + } + Env srcEnv = transform(env, filter.source()); + Env condEnv = transform(srcEnv, filter.cond()); + String condPat = condEnv.current(); + String filterPat; + /* A condition that is already a ConcatFilters expression, or a single "$"-prefixed token without spaces, is used as the filter list directly; anything else is wrapped in brackets. */ if (condPat.startsWith("(ConcatFilters") || (condPat.startsWith("$") && !condPat.contains(" "))) { + filterPat = condPat; + } else { + filterPat = "[" + condPat + "]"; + } + String pattern = N("Select", srcEnv.current(), filterPat); + return condEnv.setPattern(pattern).focus(pattern); + } + /** Emits the replacement that pushes one filter item below a set operation. It reads the variables bound under "left", "right", "colmap", "filter" and "item", wraps each input in a Select whose single filter is the item mapped through MapSetOpFilterLeft or MapSetOpFilterRight, recombines the two Selects in a node named "Union" with the colmap variable as its last operand, and places an outer Select on top that keeps the remaining filters via RemoveFiltersItem. The inner node name is the literal "Union" regardless of the matched set operator. */ private Env transformFilterSetOpTranspose(Env env) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String colmapVar = get(env, "colmap"); + String filterVar = get(env, "filter"); + String itemVar = get(env, "item"); + String leftSelect = "(Select\n" + " " + r(leftVar) + "\n" + " [ (FiltersItem (MapSetOpFilterLeft " + r(itemVar) + " " + r(colmapVar) + ")) ]\n" + " )"; + String rightSelect = "(Select\n" + " " + r(rightVar) + "\n" + " [ (FiltersItem (MapSetOpFilterRight " + r(itemVar) + " " + r(colmapVar) + ")) ]\n" + " )"; + String innerUnion = "(Union\n" + " " + leftSelect + "\n" + " " + rightSelect + "\n" + " " + r(colmapVar) + "\n" + " )"; + String pattern = N("Select", innerUnion, "(RemoveFiltersItem " + r(filterVar) + " " + r(itemVar) + ")" ); + return env.setPattern(pattern).focus(pattern); + } + /** Emits the replacement for a Project. Cases are tried in order: (1) when "input", "cond", "proj" and "passthrough" are all bound, a Project over a Select of the input with the condition is emitted; (2) for rule "ProjectMerge" with "innerPassthrough" and "passthrough" bound, a single Project using MergeProjections and DifferenceCols is emitted, with variable names recovered from the match pattern text; (3) otherwise the source and the projection map are transformed and combined, using the variable bound under "passthrough" (or the name "passthrough" when unbound) as the last operand. */ @Override + public Env transformProject(Env env, RelRN.Project project) { + if (env.bindings().containsKey("input") && env.bindings().containsKey("cond") && env.bindings().containsKey("proj") && env.bindings().containsKey("passthrough")) { + String inputVar = get(env, "input"); + String condVar = get(env, "cond"); + String projVar = get(env, "proj"); + String passthroughVar = get(env, "passthrough"); + String innerSelect = "(Select\n" + " " + r(inputVar) + "\n" + " " + r(condVar) + "\n" + ")"; + String pattern = N("Project", innerSelect, r(projVar), r(passthroughVar)); + return env.setPattern(pattern).focus(pattern); + } + if
(env.rulename.equals("ProjectMerge") && env.bindings().containsKey("innerPassthrough") && env.bindings().containsKey("passthrough")) { + String innerPassVar = get(env, "innerPassthrough"); + String outerPassVar = get(env, "passthrough"); + /* Collect, in order of appearance, every variable written as "$name:*" in the match pattern; positions 0, 1 and 3 are then used as the inner input, the inner projections and the outer projections, with fixed fallback names when the pattern has fewer binders. */ java.util.List vars = new java.util.ArrayList<>(); + java.util.regex.Matcher m = java.util.regex.Pattern .compile("\\$([a-zA-Z_][a-zA-Z0-9_]*):\\*").matcher(env.pattern()); + while (m.find()) vars.add(m.group(1)); + String input1Var = vars.size() > 0 ? vars.get(0) : "input_1"; + String proj2Var = vars.size() > 1 ? vars.get(1) : "proj_2"; + String proj0Var = vars.size() > 3 ? vars.get(3) : "proj_0"; + String pattern = "(Project\n" + " " + r(input1Var) + "\n" + " (MergeProjections\n" + " " + r(proj0Var) + "\n" + " " + r(proj2Var) + "\n" + " " + r(outerPassVar) + "\n" + " )\n" + " (DifferenceCols\n" + " " + r(innerPassVar) + "\n" + " (ProjectionCols " + r(proj2Var) + ")\n" + " )\n" + ")"; + return env.setPattern(pattern).focus(pattern); + } + Env srcEnv = transform(env, project.source()); + Env projEnv = transform(srcEnv, project.map()); + String passVar = get(projEnv, "passthrough", "passthrough"); + String pattern = N("Project", srcEnv.current(), projEnv.current(), r(passVar)); + return projEnv.setPattern(pattern).focus(pattern); + } + /** Returns the value of the first binding, in the iteration order of env.bindings().asJava(), whose key starts with "proj", does not contain "pass", and whose value ends in an integer suffix (the text after the last underscore) smaller than 3. Returns "proj_0" when no binding qualifies. The key test is the simplified equivalent of the earlier, longer condition: a key equal to "proj" already starts with "proj", and a key containing "innerPass" already contains "pass". */ private static String findFirstProjVar(Env env) { + for (var e : env.bindings().asJava().entrySet()) { + String k = e.getKey(), v = e.getValue(); + if (k.startsWith("proj") && !k.contains("pass")) { + try { + int n = Integer.parseInt(v.substring(v.lastIndexOf('_') + 1)); + if (n < 3) return v; + } catch (NumberFormatException ignored) { /* value has no numeric suffix; try the next binding */ } + } + } + return "proj_0"; + } + /** Same selection as findFirstProjVar, except that the integer suffix must be at least 2 and the fallback is "proj_2". */ private static String findSecondProjVar(Env env) { + for (var e : env.bindings().asJava().entrySet()) { + String k = e.getKey(), v = e.getValue(); + if (k.startsWith("proj") && !k.contains("pass")) { + try { + int n =
Integer.parseInt(v.substring(v.lastIndexOf('_') + 1)); + if (n >= 2) return v; + } catch (NumberFormatException ignored) { /* value has no numeric suffix; try the next binding */ } + } + } + return "proj_2"; + } + /** Returns the value of the first binding, in the iteration order of env.bindings().asJava(), whose key is "input" or starts with "input_"; returns "input_1" when there is none. */ private static String findInputVar(Env env) { + for (var e : env.bindings().asJava().entrySet()) { + String k = e.getKey(), v = e.getValue(); + if (k.equals("input") || k.startsWith("input_")) { + return v; + } + } + return "input_1"; + } + /** Emits the replacement for a Join. The private operand is looked up under "private_" plus the identity hash of the join node, then under "last_private", and finally defaults to the name "private". Rule-specific flags are checked in order (joinReduceTrue, joinReduceFalse, isJoinAddRedundantSemiJoin, isJoinCommute, isJoinConditionPush); when none applies, the left input, right input and condition are transformed in that order and combined under the node name returned by joinType. */ @Override + public Env transformJoin(Env env, RelRN.Join join) { + String jType = joinType(join.ty().semantics()); + String privateVar = get(env, "private_" + System.identityHashCode(join), get(env, "last_private", "private")); + if (flag(env, "joinReduceTrue")) { + Env leftEnv = transform(env, join.left()); + Env rightEnv = transform(leftEnv, join.right()); + String onVar = get(rightEnv, "on"); + String itemVar = get(rightEnv, "item"); + String pattern = N(jType, leftEnv.current(), rightEnv.current(), "(RemoveFiltersItem " + r(onVar) + " " + r(itemVar) + ")", r(privateVar) ); + return rightEnv.setPattern(pattern).focus(pattern); + } + if (flag(env, "joinReduceFalse")) { + Env leftEnv = transform(env, join.left()); + Env rightEnv = transform(leftEnv, join.right()); + String pattern = N(jType, leftEnv.current(), rightEnv.current(), "[ (FiltersItem (False)) ]", r(privateVar) ); + return rightEnv.setPattern(pattern).focus(pattern); + } + if (flag(env, "isJoinAddRedundantSemiJoin") && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("filters") && env.bindings().containsKey("private")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String filtersVar = get(env, "filters"); + String privVar = get(env, "private"); + String semiJoin = "(SemiJoin\n" + " " + r(leftVar) + "\n" + " " + r(rightVar) + "\n" + " " + r(filtersVar) + "\n" + " (EmptyJoinPrivate)\n" + " )"; + String pattern = N("InnerJoin", semiJoin, r(rightVar), r(filtersVar), r(privVar) ); + return env.setPattern(pattern).focus(pattern); +
} + /* Join commute: swap the inputs of an InnerJoin, adjust the private flags with CommuteJoinFlags, and restore the original column order with a Project built from SwapJoinOutputColumns. */ if (flag(env, "isJoinCommute")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String onVar = get(env, "on"); + String privVar = get(env, "private"); + String innerJoin = "(InnerJoin\n" + " " + r(rightVar) + "\n" + " " + r(leftVar) + "\n" + " " + r(onVar) + "\n" + " (CommuteJoinFlags " + r(privVar) + ")\n" + " )"; + String swapCols = "(SwapJoinOutputColumns\n" + " (OutputCols " + r(leftVar) + ")\n" + " (OutputCols " + r(rightVar) + ")\n" + " )"; + String pattern = N("Project", innerJoin, swapCols, "(MakeEmptyColSet)"); + return env.setPattern(pattern).focus(pattern); + } + /* Join condition push for inner joins. NOTE(review): the negated joinReduceTrue, joinReduceFalse and isJoinCommute checks look redundant, because the earlier branches of this method return whenever those flags are set; only the isJoinExtractFilter check adds a restriction here. */ if (flag(env, "isJoinConditionPush") && join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("on") && env.bindings().containsKey("private") && !flag(env, "joinReduceTrue") && !flag(env, "joinReduceFalse") && !flag(env, "isJoinCommute") && !flag(env, "isJoinExtractFilter")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String onVar = get(env, "on"); + String privVar = get(env, "private"); + String leftSel = "(Select " + r(leftVar) + " (ExtractBoundConditions " + r(onVar) + " (OutputCols " + r(leftVar) + ")))"; + String rightSel = "(Select " + r(rightVar) + " (ExtractBoundConditions " + r(onVar) + " (OutputCols " + r(rightVar) + ")))"; + String unboundCond = "(ExtractUnboundConditions\n" + " (ExtractUnboundConditions " + r(onVar) + " (OutputCols " + r(leftVar) + "))\n" + " (OutputCols " + r(rightVar) + ")\n" + " )"; + String pattern = N("InnerJoin", leftSel, rightSel, unboundCond, r(privVar)); + return env.setPattern(pattern).focus(pattern); + } + Env leftEnv = transform(env, join.left()); + Env rightEnv = transform(leftEnv, join.right()); + Env condEnv = transform(rightEnv, join.cond()); + String pattern = N(jType, leftEnv.current(), rightEnv.current(), condEnv.current(), r(privateVar)); + return
condEnv.setPattern(pattern).focus(pattern); + } + /** Emits the replacement for a join whose conditions have been pushed into its inputs. When the "isJoinConditionPush" flag is set and "left", "right", "on" and "private" are bound, it emits an InnerJoin of two Selects that keep the conditions bound by each side's output columns (ExtractBoundConditions), with the conditions bound by neither side left as the join condition (nested ExtractUnboundConditions). This is the same text that transformJoin builds in its own isJoinConditionPush branch. Otherwise the inputs and condition are transformed generically. NOTE(review): the generic path creates a fresh "private" variable with generateVar instead of reading a bound one as transformJoin does; confirm that a variable not bound on the match side is intended here. */ @Override + public Env transformJoinWithPushedConds(Env env, RelRN.JoinWithPushedConds join) { + if (flag(env, "isJoinConditionPush") && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("on") && env.bindings().containsKey("private")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String onVar = get(env, "on"); + String privVar = get(env, "private"); + String leftSel = "(Select " + r(leftVar) + " (ExtractBoundConditions " + r(onVar) + " (OutputCols " + r(leftVar) + ")))"; + String rightSel = "(Select " + r(rightVar) + " (ExtractBoundConditions " + r(onVar) + " (OutputCols " + r(rightVar) + ")))"; + String unboundCond = "(ExtractUnboundConditions\n" + " (ExtractUnboundConditions " + r(onVar) + " (OutputCols " + r(leftVar) + "))\n" + " (OutputCols " + r(rightVar) + ")\n" + " )"; + String pattern = N("InnerJoin", leftSel, rightSel, unboundCond, r(privVar)); + return env.setPattern(pattern).focus(pattern); + } + Env leftEnv = transform(env, join.left()); + Env rightEnv = transform(leftEnv, join.right()); + Env condEnv = transform(rightEnv, join.cond()); + String privVar = condEnv.generateVar("private"); + String jType = joinType(join.ty().semantics()); + String pattern = N(jType, leftEnv.current(), rightEnv.current(), condEnv.current(), r(privVar)); + return condEnv.setPattern(pattern).focus(pattern); + } + /** Emits the replacement for a Union. Cases are tried in order: (1) the "isFilterSetOpTranspose" flag, handled by transformFilterSetOpTranspose; (2) the "isUnionMerge" flag with all nested-union column bindings present, which re-associates the union to the right and rebuilds both private operands with MakeSetPrivate; (3) the full set of left/right projection bindings, emitted as UnionPullUpConstantsReplace; (4) the "hasZeroRows" flag, emitted as ConstructEmptyValues; (5) otherwise every source is transformed in order and recombined. NOTE(review): case (1) does not check that the bindings read by transformFilterSetOpTranspose exist, unlike the corresponding check in transformFilter, and case (3) is selected by bindings alone without a rule flag; confirm both are intended. */ @Override + public Env transformUnion(Env env, RelRN.Union union) { + if (flag(env, "isFilterSetOpTranspose")) { + return transformFilterSetOpTranspose(env); + } + if (flag(env, "isUnionMerge") && env.bindings().containsKey("leftLeft") && env.bindings().containsKey("leftRight") && env.bindings().containsKey("right") && env.bindings().containsKey("innerLeftCols") && env.bindings().containsKey("innerRightCols") && env.bindings().containsKey("innerOutCols") && env.bindings().containsKey("outerRightCols") &&
env.bindings().containsKey("outerOutCols")) { + String llVar = get(env, "leftLeft"); + String lrVar = get(env, "leftRight"); + String rVar = get(env, "right"); + String ilcVar = get(env, "innerLeftCols"); + String ircVar = get(env, "innerRightCols"); + String iocVar = get(env, "innerOutCols"); + String orcVar = get(env, "outerRightCols"); + String oocVar = get(env, "outerOutCols"); + String uType = union.all() ? "UnionAll" : "Union"; + String inner = "(" + uType + "\n" + " " + r(lrVar) + "\n" + " " + r(rVar) + "\n" + " (MakeSetPrivate " + r(ircVar) + " " + r(orcVar) + " " + r(iocVar) + ")\n" + " )"; + String pattern = "(" + uType + "\n" + " " + r(llVar) + "\n" + " " + inner + "\n" + " (MakeSetPrivate " + r(ilcVar) + " " + r(iocVar) + " " + r(oocVar) + ")\n" + ")"; + return env.setPattern(pattern).focus(pattern); + } + if (env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("leftProjections") && env.bindings().containsKey("rightProjections") && env.bindings().containsKey("private") && env.bindings().containsKey("leftInput") && env.bindings().containsKey("rightInput") && env.bindings().containsKey("leftPassthrough") && env.bindings().containsKey("rightPassthrough")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String leftProjectionsVar = get(env, "leftProjections"); + String rightProjectionsVar = get(env, "rightProjections"); + String privateVar = get(env, "private"); + String leftInputVar = get(env, "leftInput"); + String rightInputVar = get(env, "rightInput"); + String leftPassthroughVar = get(env, "leftPassthrough"); + String rightPassthroughVar = get(env, "rightPassthrough"); + String pattern = "(UnionPullUpConstantsReplace\n" + " " + r(leftVar) + "\n" + " " + r(rightVar) + "\n" + " " + r(leftProjectionsVar) + "\n" + " " + r(rightProjectionsVar) + "\n" + " " + r(privateVar) + "\n" + " " + r(leftInputVar) + "\n" + " " + r(rightInputVar) + "\n" + " " + r(leftPassthroughVar) + 
"\n" + " " + r(rightPassthroughVar) + "\n" + ")"; + return env.setPattern(pattern).focus(pattern); + } + if (flag(env, "hasZeroRows")) { + String leftVar = get(env, "zeroInput", "input"); + String pattern = "(ConstructEmptyValues (OutputCols " + r(leftVar) + "))"; + return env.setPattern(pattern).focus(pattern); + } + Env currentEnv = env; + Seq sourcePatterns = Seq.empty(); + for (RelRN source : union.sources()) { + Env srcEnv = transform(currentEnv, source); + sourcePatterns = sourcePatterns.appended(srcEnv.current()); + currentEnv = srcEnv; + } + String privateVar = get(currentEnv, "union_private", "private"); + String uType = union.all() ? "UnionAll" : "Union"; + String pattern; + if (sourcePatterns.size() == 2) { + pattern = N(uType, sourcePatterns.get(0), sourcePatterns.get(1), r(privateVar)); + } else { + String nestedPrivate = get(currentEnv, "inner_union_private", privateVar); + String nested = buildNestedUnionTransform(uType, sourcePatterns.drop(1), nestedPrivate); + pattern = N(uType, sourcePatterns.get(0), nested, r(privateVar)); + } + return currentEnv.setPattern(pattern).focus(pattern); + } + private String buildNestedUnionTransform(String uType, Seq sources, String privVar) { + if (sources.size() == 2) return N(uType, sources.get(0), sources.get(1), r(privVar)); + return N(uType, sources.get(0), buildNestedUnionTransform(uType, sources.drop(1), privVar), r(privVar)); + } + @Override + public Env transformIntersect(Env env, RelRN.Intersect intersect) { + if (flag(env, "isIntersectMerge") && env.bindings().containsKey("leftLeft") && env.bindings().containsKey("leftRight") && env.bindings().containsKey("right") && env.bindings().containsKey("innerLeftCols") && env.bindings().containsKey("innerRightCols") && env.bindings().containsKey("outerRightCols") && env.bindings().containsKey("outerOutCols")) { + String llVar = get(env, "leftLeft"); + String lrVar = get(env, "leftRight"); + String rVar = get(env, "right"); + String ilcVar = get(env, 
"innerLeftCols"); + String ircVar = get(env, "innerRightCols"); + String orcVar = get(env, "outerRightCols"); + String oocVar = get(env, "outerOutCols"); + String iType = intersect.all() ? "IntersectAll" : "Intersect"; + String inner = "(" + iType + "\n" + " " + r(lrVar) + "\n" + " " + r(rVar) + "\n" + " (MakeSetPrivate " + r(ircVar) + " " + r(orcVar) + " " + r(ircVar) + ")\n" + " )"; + String pattern = N(iType, r(llVar), inner, "(MakeSetPrivate " + r(ilcVar) + " " + r(ircVar) + " " + r(oocVar) + ")" ); + return env.setPattern(pattern).focus(pattern); + } + if (flag(env, "isPruneEmptyIntersect")) { + String leftVar = get(env, "pruneEmptyLeft"); + String pattern = "(ConstructEmptyValues (OutputCols " + r(leftVar) + "))"; + return env.setPattern(pattern).focus(pattern); + } + Env currentEnv = env; + Seq sourcePatterns = Seq.empty(); + for (RelRN source : intersect.sources()) { + Env srcEnv = transform(currentEnv, source); + sourcePatterns = sourcePatterns.appended(srcEnv.current()); + currentEnv = srcEnv; + } + String privateVar = get(currentEnv, "intersect_private", "private"); + String iType = intersect.all() ? 
"IntersectAll" : "Intersect"; + String pattern; + if (sourcePatterns.size() == 2) { + pattern = N(iType, sourcePatterns.get(0), sourcePatterns.get(1), r(privateVar)); + } else { + String nestedPrivate = get(currentEnv, "inner_intersect_private", privateVar); + String nested = buildNestedIntersectTransform(iType, sourcePatterns.drop(1), nestedPrivate); + pattern = N(iType, sourcePatterns.get(0), nested, r(privateVar)); + } + return currentEnv.setPattern(pattern).focus(pattern); + } + private String buildNestedIntersectTransform(String iType, Seq sources, String privVar) { + if (sources.size() == 2) return N(iType, sources.get(0), sources.get(1), r(privVar)); + return N(iType, sources.get(0), buildNestedIntersectTransform(iType, sources.drop(1), privVar), r(privVar)); + } + @Override + public Env transformMinus(Env env, RelRN.Minus minus) { + if (flag(env, "isPruneEmptyMinus")) { + String leftVar = get(env, "pruneEmptyLeft"); + String pattern = "(ConstructEmptyValues (OutputCols " + r(leftVar) + "))"; + return env.setPattern(pattern).focus(pattern); + } + if (flag(env, "isMinusMerge") && env.bindings().containsKey("leftLeft") && env.bindings().containsKey("leftRight") && env.bindings().containsKey("right") && env.bindings().containsKey("innerPrivate") && env.bindings().containsKey("outerPrivate")) { + String llVar = get(env, "leftLeft"); + String lrVar = get(env, "leftRight"); + String rVar = get(env, "right"); + String ipVar = get(env, "innerPrivate"); + String opVar = get(env, "outerPrivate"); + String pattern = N("ConstructMinusMergeResult", r(llVar), r(lrVar), r(rVar), r(ipVar), r(opVar)); + return env.setPattern(pattern).focus(pattern); + } + String pattern = "(Except\n" + " $left\n" + " (Union\n" + " $rightB\n" + " $rightC\n" + " (MakeUnionPrivateForExcept $pInner $pOuter)\n" + " )\n" + " $pOuter\n" + ")"; + return env.setPattern(pattern).focus(pattern); + } + @Override + public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + String aggType = 
determineAggregateType(aggregate); + if (env.bindings().containsKey("left") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate") && !env.bindings().containsKey("topOn") && !env.bindings().containsKey("topPrivate")) { + String leftVar = get(env, "left"); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + if (env.bindings().containsKey("rightFilters")) { + String rightVar = get(env, "right"); + String rfVar = get(env, "rightFilters"); + String lj = "(LeftJoin\n" + " " + r(leftVar) + "\n" + " " + r(rightVar) + "\n" + " " + r(rfVar) + "\n" + " (EmptyJoinPrivate)\n" + " )"; + String pattern = N("DistinctOn", lj, r(aggsVar), r(gpVar)); + return env.setPattern(pattern).focus(pattern); + } + String pattern = N("DistinctOn", r(leftVar), r(aggsVar), r(gpVar)); + return env.setPattern(pattern).focus(pattern); + } + if (env.rulename.equals("AggregateProjectConstantToDummyJoin") && env.bindings().containsKey("input") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate") && !env.bindings().containsKey("projectInput")) { + String inputVar = get(env, "input"); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + String pattern = N("ConstructAggregateProjectConstantToDummyJoin", r(inputVar), r(aggsVar), r(gpVar)); + return env.setPattern(pattern).focus(pattern); + } + if (env.rulename.equals("AggregateProjectMerge") && env.bindings().containsKey("input") && env.bindings().containsKey("source") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate") && !env.bindings().containsKey("projections")) { + String inputVar = get(env, "input"); + String srcVar = get(env, "source"); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + String pattern = N(aggType, r(srcVar), "(MergeProjectIntoAggregate " + r(inputVar) + " " + r(aggsVar) + ")", r(gpVar) ); 
+ return env.setPattern(pattern).focus(pattern); + } + if (env.bindings().containsKey("input") && env.bindings().containsKey("projections") && env.bindings().containsKey("passthrough") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate")) { + String inputVar = get(env, "input"); + String projVar = get(env, "projections"); + String passVar = get(env, "passthrough"); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + String remapAggs = "(RemapAggregationsThroughProject " + r(aggsVar) + " " + r(projVar) + ")"; + String remapGp = "(RemapGroupingColsThroughProject " + r(gpVar) + " " + r(projVar) + " " + r(passVar) + ")"; + String pattern = N(aggType, r(inputVar), remapAggs, remapGp); + return env.setPattern(pattern).focus(pattern); + } + if (flag(env, "isAggregateExtractProject")) { + String inputVar = get(env, "input", "input"); + String aggsVar = get(env, "aggregations", "aggregations"); + String gpVar = get(env, "groupingPrivate", "groupingPrivate"); + String pattern = "(ConstructAggregateExtractProject\n" + " " + r(inputVar) + "\n" + " " + r(aggsVar) + "\n" + " " + r(gpVar) + "\n" + ")"; + return env.setPattern(pattern).focus(pattern); + } + Env srcEnv = transform(env, aggregate.source()); + Env groupEnv = transformGroupSet(srcEnv, aggregate.groupSet()); + Env aggsEnv = transformAggCalls(groupEnv, aggregate.aggCalls()); + String privVar = get(aggsEnv, "aggregate_private", "private"); + String pattern = N(aggType, srcEnv.current(), aggsEnv.current(), r(privVar)); + return aggsEnv.setPattern(pattern).focus(pattern); + } + private Env transformAggCalls(Env env, Seq aggCalls) { + Env currentEnv = env; + Seq aggPatterns = Seq.empty(); + for (RelRN.AggCall aggCall : aggCalls) { + String aggVar = get(currentEnv, aggCall.name(), "agg"); + aggPatterns = aggPatterns.appended(r(aggVar)); + currentEnv = currentEnv.focus(r(aggVar)); + } + String pattern; + if (aggCalls.size() == 1) { + String 
aggVar = get(currentEnv, "aggregations", "aggregations"); + pattern = r(aggVar); + } else { + pattern = "[" + aggPatterns.joinToString(" ") + "]"; + } + return currentEnv.setPattern(pattern).focus(pattern); + } + private Env transformGroupSet(Env env, Seq groupSet) { + Env currentEnv = env; + Seq groupPatterns = Seq.empty(); + for (RexRN g : groupSet) { + Env ge = transform(currentEnv, g); + groupPatterns = groupPatterns.appended(ge.current()); + currentEnv = ge; + } + String pattern = "[" + groupPatterns.joinToString(" ") + "]"; + return currentEnv.setPattern(pattern).focus(pattern); + } + @Override + public Env transformEmpty(Env env, RelRN.Empty empty) { + if (flag(env, "isPruneEmptyMinus")) { + String leftVar = get(env, "pruneEmptyLeft"); + return env.setPattern("(ConstructEmptyValues (OutputCols " + r(leftVar) + "))") .focus("(ConstructEmptyValues (OutputCols " + r(leftVar) + "))"); + } + if (flag(env, "hasZeroRows")) { + if (env.bindings().containsKey("projections") && env.bindings().containsKey("passthrough")) { + String projVar = get(env, "projections"); + String passVar = get(env, "passthrough"); + String pattern = "(ConstructEmptyValues (UnionCols (ProjectionCols " + r(projVar) + ") " + r(passVar) + "))"; + return env.setPattern(pattern).focus(pattern); + } + String inputVar = get(env, "zeroInput", "input"); + String matchPat = env.pattern(); + if (matchPat != null && matchPat.contains("Union") && matchPat.contains("$left")) { + String pattern = "(ConstructEmptyValues (OutputCols " + r(inputVar) + "))"; + return env.setPattern(pattern).focus(pattern); + } + String pattern = r(inputVar); + return env.setPattern(pattern).focus(pattern); + } + if (flag(env, "isPruneEmptyFilter")) { + String inputVar = get(env, "pruneEmptyInput"); + return env.setPattern(r(inputVar)).focus(r(inputVar)); + } + String pattern = "(ConstructEmptyValues (OutputCols $input_0))"; + return env.setPattern(pattern).focus(pattern); + } + @Override + public Env transformField(Env env, 
RexRN.Field field) { + String var = get(env, "field_" + field.ordinal(), "field"); + return env.setPattern(r(var)).focus(r(var)); + } + @Override + public Env transformPred(Env env, RexRN.Pred pred) { + String var = get(env, pred.operator().getName(), "cond"); + return env.setPattern(r(var)).focus(r(var)); + } + @Override + public Env transformProj(Env env, RexRN.Proj proj) { + String var = get(env, proj.operator().getName(), "proj"); + return env.setPattern(r(var)).focus(r(var)); + } + public Env transformGroupBy(Env env, RexRN.GroupBy groupBy) { + if (groupBy.sources().size() == 1 && groupBy.sources().get(0) instanceof RexRN.Proj proj) { + String projVar = env.bindings().get(proj.operator().getName()); + if (projVar != null) return env.setPattern(r(projVar)).focus(r(projVar)); + } + String var = get(env, groupBy.operator().getName(), "groupBy"); + return env.setPattern(r(var)).focus(r(var)); + } + @Override + public Env transformAnd(Env env, RexRN.And and) { + Env currentEnv = env; + Seq operandPats = Seq.empty(); + for (RexRN op : and.sources()) { + Env opEnv = transform(currentEnv, op); + operandPats = operandPats.appended(opEnv.current()); + currentEnv = opEnv; + } + String pattern = "(ConcatFilters " + operandPats.joinToString(" ") + ")"; + return currentEnv.setPattern(pattern).focus(pattern); + } + @Override + public Env transformTrue(Env env, RexRN literal) { + String var = get(env, "true_" + System.identityHashCode(literal), "true"); + return env.setPattern(r(var)).focus(r(var)); + } + @Override + public Env transformFalse(Env env, RexRN literal) { + return env.setPattern("(False)").focus("(False)"); + } + @Override + public Env transformCustom(Env env, RelRN custom) { + if (custom instanceof org.qed.RRuleInstances.JoinCommute.ProjectionRelRN proj) { + return transform(env, proj.source()); + } + if (custom instanceof org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin .AggregateGroupingByDummyFields aggGrouping) { + if 
(env.rulename.equals("AggregateProjectConstantToDummyJoin") && env.bindings().containsKey("input") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate") && !env.bindings().containsKey("projectInput")) { + String inputVar = get(env, "input"); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + String pattern = N("ConstructAggregateProjectConstantToDummyJoin", r(inputVar), r(aggsVar), r(gpVar)); + return env.setPattern(pattern).focus(pattern); + } + return transform(env, aggGrouping.input()); + } + if (custom instanceof org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin .ProjectWithDummyFields pwd) { + return transform(env, pwd.input()); + } + if (custom instanceof org.qed.RRuleInstances.AggregateProjectConstantToDummyJoin .SourceTable) { + return transformScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge .ProjectOptimized projectOptimized) { + if (env.rulename.equals("ProjectAggregateMerge") && env.bindings().containsKey("aggInput") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate") && env.bindings().containsKey("projections") && env.bindings().containsKey("passthrough") && env.bindings().containsKey("needed")) { + if (projectOptimized.input() instanceof org.qed.RRuleInstances.ProjectAggregateMerge .AggregateWithUsedCallsOnly aggregateOptimized) { + Env aggInputEnv = transform(env, aggregateOptimized.input()); + String aggInputPat = aggInputEnv.current(); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + String projVar = get(env, "projections"); + String passVar = get(env, "passthrough"); + String neededVar = get(env, "needed"); + String pattern = "(Project\n" + " (GroupBy\n" + " " + aggInputPat + "\n" + " (PruneAggCols " + r(aggsVar) + " " + r(neededVar) + ")\n" + " " + r(gpVar) + 
"\n" + " )\n" + " " + r(projVar) + "\n" + " " + r(passVar) + "\n" + ")"; + return aggInputEnv.setPattern(pattern).focus(pattern); + } + } + return transform(env, projectOptimized.input()); + } + if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge .ProjectUsingSubsetOfAggregates pusa) { + if (env.bindings().containsKey("input") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingPrivate") && env.bindings().containsKey("projections") && env.bindings().containsKey("passthrough")) { + String inputVar = get(env, "input"); + String projVar = get(env, "projections"); + String passVar = get(env, "passthrough"); + String aggsVar = get(env, "aggregations"); + String gpVar = get(env, "groupingPrivate"); + String remapAggs = "(RemapAggregationsThroughProject " + r(aggsVar) + " " + r(projVar) + ")"; + String remapGp = "(RemapGroupingColsThroughProject " + r(gpVar) + " " + r(projVar) + " " + r(passVar) + ")"; + String pattern = N("GroupBy", r(inputVar), remapAggs, remapGp); + return env.setPattern(pattern).focus(pattern); + } + return transform(env, pusa.input()); + } + if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge .AggregateWithMultipleCalls amc) { + return transform(env, amc.input()); + } + if (custom instanceof org.qed.RRuleInstances.ProjectAggregateMerge.SourceTable) { + return transformScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .TopProjectionWithConstants topProj) { + if (env.rulename.equals("UnionPullUpConstants") && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("leftProjections") && env.bindings().containsKey("rightProjections") && env.bindings().containsKey("private") && env.bindings().containsKey("leftInput") && env.bindings().containsKey("rightInput") && env.bindings().containsKey("leftPassthrough") && 
env.bindings().containsKey("rightPassthrough")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String leftProjVar = get(env, "leftProjections"); + String rightProjVar = get(env, "rightProjections"); + String privateVar = get(env, "private"); + String leftInputVar = get(env, "leftInput"); + String rightInputVar = get(env, "rightInput"); + String leftPassVar = get(env, "leftPassthrough"); + String rightPassVar = get(env, "rightPassthrough"); + String pattern = "(UnionPullUpConstantsReplace\n" + " " + r(leftVar) + "\n" + " " + r(rightVar) + "\n" + " " + r(leftProjVar) + "\n" + " " + r(rightProjVar) + "\n" + " " + r(privateVar) + "\n" + " " + r(leftInputVar) + "\n" + " " + r(rightInputVar) + "\n" + " " + r(leftPassVar) + "\n" + " " + r(rightPassVar) + "\n" + ")"; + return env.setPattern(pattern).focus(pattern); + } + return transform(env, topProj.input()); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .UnionWithConstantColumns uwcc) { + return transformUnion(env, new RelRN.Union(true, Seq.of(uwcc.left(), uwcc.right()))); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .LeftProjectionWithConstants lp) { + return transformProject(env, new RelRN.Project(lp.input().field(0), lp.input())); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants .RightProjectionWithConstants rp) { + return transformProject(env, new RelRN.Project(rp.input().field(0), rp.input())); + } + if (custom instanceof org.qed.RRuleInstances.UnionPullUpConstants.SourceTable) { + return transformScan(env, new RelRN.Scan("Source", org.qed.RexRN.varType("Source_Type", true), false)); + } + if (custom instanceof org.qed.RRuleInstances.UnionToDistinct.DistinctAggregate da) { + if (env.rulename.equals("UnionToDistinct") && env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("private") && env.bindings().containsKey("leftCols") && 
env.bindings().containsKey("rightCols") && env.bindings().containsKey("outCols") && env.bindings().containsKey("keyCols")) { + String leftVar = get(env, "left"); + String rightVar = get(env, "right"); + String privVar = get(env, "private"); + String lcVar = get(env, "leftCols"); + String rcVar = get(env, "rightCols"); + String ocVar = get(env, "outCols"); + String kcVar = get(env, "keyCols"); + String translateCols = "(TranslateColSet\n" + " (DifferenceCols (OutputCols " + r(leftVar) + ") " + r(kcVar) + ")\n" + " " + r(lcVar) + "\n" + " " + r(ocVar) + "\n" + " )"; + String makeAgg = "(MakeAggCols\n" + " ConstAgg\n" + " " + translateCols + "\n" + " )"; + String translateKeyCols = "(TranslateColSet " + r(kcVar) + " " + r(lcVar) + " " + r(ocVar) + ")"; + String makeGrouping = "(MakeGrouping\n" + " " + translateKeyCols + "\n" + " (EmptyOrdering)\n" + " )"; + String unionAllPat = "(UnionAll " + r(leftVar) + " " + r(rightVar) + " " + r(privVar) + ")"; + String pattern = N("DistinctOn", unionAllPat, makeAgg, makeGrouping); + return env.setPattern(pattern).focus(pattern); + } + return transform(env, da.input()); + } + if (custom instanceof org.qed.RRuleInstances.UnionToDistinct.UnionAll ua) { + Env leftEnv = transform(env, ua.left()); + Env rightEnv = transform(leftEnv, ua.right()); + return rightEnv; + } + return unimplementedTransform(env, custom); + } + @Override + public Env transformCustom(Env env, RexRN custom) { + if (custom instanceof RexRN.GroupBy groupBy) return transformGroupBy(env, groupBy); + return unimplementedTransform(env, custom); + } + @Override + public String translate(String name, Env onMatch, Env transform) { + String match = postProcessMatch(onMatch.pattern()); + String out = postProcessTransform(transform.pattern(), match); + return "[" + name + ", Normalize]\n" + match + "\n=>\n" + out + "\n"; + } + private static String postProcessMatch(String match) { + if (match == null) return ""; + if (match.contains("HasZeroRows")) { + match = 
normalizeVars(match, "projections", "passthrough", "filters"); + } + if (match.startsWith("(Union\n") || match.startsWith("(UnionAll\n")) { + String[] lines = match.split("\n"); + if (lines.length >= 3 && lines[1].contains(":(Values)") && lines[2].contains(":(Values)")) { + String leftVar = extractVar(lines[1]); + String rightVar = extractVar(lines[2]); + String uType = lines[0].startsWith("(UnionAll") ? "UnionAll" : "Union"; + match = "(" + uType + "\n" + " " + b(leftVar) + " & (HasZeroRows " + r(leftVar) + ")\n" + " " + b(rightVar) + " & (HasZeroRows " + r(rightVar) + ")\n" + ")"; + } + } + if (match.contains("HasZeroRows") && match.contains("SetPrivate")) { + match = match.replaceAll("\\s+\\$private_\\d+:\\*\\s*\\)", "\n)") .replaceAll("\\s+\\$private_\\d+:\\*\\)", ")"); + } + return match; + } + private static String postProcessTransform(String out, String match) { + if (out == null) return ""; + if (out.startsWith("(ConstructEmptyValues (OutputCols $")) { + int startIdx = "(ConstructEmptyValues (OutputCols $".length(); + String var = extractVarFromPos(out, startIdx); + if (var.equals("input")) { + String numbered = findFirstVar(match); + if (numbered != null) { + out = out.replace( "(ConstructEmptyValues (OutputCols $input)", "(ConstructEmptyValues (OutputCols $" + numbered + ")"); + } + } + } else if (out.equals("$input")) { + String numbered = findFirstVar(match); + if (numbered != null) out = "$" + numbered; + } + if (match.contains("HasZeroRows") && match.contains("SetPrivate") && match.contains("outCols") && out.contains("ConstructEmptyValues")) { + java.util.regex.Matcher m = java.util.regex.Pattern .compile("\\$outCols_(\\d+)").matcher(match); + if (m.find()) { + String outColsVar = "$outCols_" + m.group(1); + int si = out.indexOf("(ConstructEmptyValues (OutputCols $"); + if (si >= 0) { + int vi = si + "(ConstructEmptyValues (OutputCols $".length(); + int vend = vi + extractVarFromPos(out, vi).length(); + out = out.substring(0, si) + 
"(ConstructEmptyValues (ColListToSet " + outColsVar + "))" + out.substring(vend + 2); + } + } + } else if (match.contains("HasZeroRows") && match.contains("$left") && out.contains("ConstructEmptyValues")) { + int leftIdx = match.indexOf("$left"); + if (leftIdx >= 0) { + String leftVar = extractVarFromPos(match, leftIdx + 1); + out = out.replaceAll( "(OutputCols \\$)[a-zA-Z_][a-zA-Z0-9_]*", "$1" + leftVar); + } + } + java.util.Map varMap = extractNumberedVarMap(match); + if (!varMap.isEmpty()) { + var sorted = new java.util.ArrayList<>(varMap.entrySet()); + sorted.sort((a, bx) -> Integer.compare(bx.getKey().length(), a.getKey().length())); + for (var e : sorted) { + out = out.replaceAll( "\\$" + java.util.regex.Pattern.quote(e.getKey()) + "(?![A-Za-z0-9_])", java.util.regex.Matcher.quoteReplacement(e.getValue())); + } + } + if (match.contains("HasZeroRows")) { + out = normalizeVars(out, "projections", "passthrough", "filters"); + } + return out; + } + private static String normalizeVars(String str, String... 
varNames) { + for (String v : varNames) { + if (str.contains("$" + v + "_")) { + str = str.replaceAll( "\\$" + v + "_\\d+", java.util.regex.Matcher.quoteReplacement("$" + v)); + } + } + return str; + } + private static String extractVar(String line) { + int i = line.indexOf('$'); + if (i < 0) return null; + return extractVarFromPos(line, i + 1); + } + private static String extractVarFromPos(String str, int pos) { + int end = pos; + while (end < str.length() && (Character.isLetterOrDigit(str.charAt(end)) || str.charAt(end) == '_')) { + end++; + } + return str.substring(pos, end); + } + private static String findFirstVar(String match) { + for (String line : match.split("\n")) { + if (line.contains("$private")) continue; + if (line.contains("$")) { + String v = extractVar(line); + if (v != null) return v; + } + } + return null; + } + private static java.util.Map extractNumberedVarMap(String match) { + java.util.Map map = new java.util.HashMap<>(); + java.util.regex.Matcher m = + java.util.regex.Pattern.compile("\\$([A-Za-z][A-Za-z0-9_]*)_([0-9]+)").matcher(match); + while (m.find()) { + String base = m.group(1); + String numbered = "$" + base + "_" + m.group(2); + map.putIfAbsent(base, numbered); + } + return map; + } + public record Env( AtomicInteger varId, String pattern, ImmutableMap bindings, String currentVar, String rulename ) { + public static Env empty(String rulename) { + return new Env(new AtomicInteger(), "", ImmutableMap.empty(), "", rulename); + } + public Env focus(String target) { + return new Env(varId, pattern, bindings, target, rulename); + } + public Env setPattern(String newPattern) { + return new Env(varId, newPattern, bindings, currentVar, rulename); + } + public Env addBinding(String key, String value) { + return new Env(varId, pattern, bindings.putted(key, value), currentVar, rulename); + } + public String generateVar(String prefix) { + return prefix + "_" + varId.getAndIncrement(); + } + public String current() { return currentVar; } + public 
String pattern() { return pattern; } + public ImmutableMap bindings() { return bindings; } + public String rulename() { return rulename; } + } +} diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java b/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java new file mode 100644 index 0000000..25b44fc --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java @@ -0,0 +1,187 @@ +package org.qed.Backends.Cockroach; + +import kala.tuple.Tuple; +import kala.collection.Seq; +import org.qed.*; +import org.apache.calcite.rel.rules.*; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgramBuilder; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +public class CockroachTester { + public static String genPath = + ProjectPaths.baseDir().resolve("src/main/java/org/qed/Backends/Cockroach/Generated").toString(); + public static String rulePath = ProjectPaths.baseDir().resolve("rules").toString(); + + // public static HepPlanner loadRules(java.util.List rules) { + // System.out.printf("Loading Rules: %s\n", + // rules.stream() + // .map(rule -> rule.getClass().getSimpleName()) + // .collect(java.util.stream.Collectors.joining(", "))); + + // var builder = new HepProgramBuilder(); + // for (var rule : rules) { + // builder.addRuleInstance(rule); + // } + // return new HepPlanner(builder.build()); + // } + + // public static HepPlanner loadRules(RelOptRule... 
rules) { + // return loadRules(java.util.Arrays.asList(rules)); + // } + + // public static HepPlanner loadRule(RelOptRule rule) { + // System.out.printf("Loading Rule: %s\n", rule.getClass().getSimpleName()); + // var builder = new HepProgramBuilder().addRuleInstance(rule); + // return new HepPlanner(builder.build()); + // } + + // public static HepPlanner loadRule(RelOptRule rule, int matchLimit) { + // System.out.printf("Loading Rule: %s (match limit: %d)\n", rule.getClass().getSimpleName(), matchLimit); + // var builder = new HepProgramBuilder() + // .addMatchLimit(matchLimit) + // .addRuleInstance(rule); + // return new HepPlanner(builder.build()); + // } + + public static Seq ruleList() { + java.io.File ruleDir = + ProjectPaths.baseDir().resolve("src/main/java/org/qed/RRuleInstances").toFile(); + java.io.File[] files = ruleDir.listFiles((dir, name) -> name.endsWith(".java")); + java.util.List rules = new java.util.ArrayList<>(); + if (files != null) { + for (java.io.File file : files) { + String className = file.getName().replace(".java", ""); + if ((className.contains("Distinct") && !className.contains("UnionToDistinct")) || + (className.contains("Pull") && !className.contains("UnionPullUpConstants")) || + className.contains("AggregativeJoinRemove")) { + continue; + } + + try { + Class clazz = Class.forName("org.qed.RRuleInstances." 
+ className); + RRule rule = (RRule) clazz.getConstructor().newInstance(); + rules.add(rule); + } catch (Exception e) { + throw new RuntimeException("Failed to load rule: " + className, e); + } + } + } + return Seq.from(rules); + // var families = Seq.from(reflections.getSubTypesOf(RRule.RRuleFamily.class)) + // .filter(clazz -> !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers())) + // .mapUnchecked(clazz -> { + // Constructor constructor = clazz.getDeclaredConstructor(); + // constructor.setAccessible(true); + // return constructor.newInstance(); + // }) + // .map(r -> (RRule.RRuleFamily) r); + // return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); + } + + // public static void verify() { + // ruleList().forEachUnchecked(rule -> rule.dump(rulePath + "/" + rule.name() + ".json")); + // } + + public static void generate() { + try { + java.io.File genDir = new java.io.File(genPath); + if (genDir.exists()) { + java.io.File[] files = genDir.listFiles((dir, name) -> name.matches(".*\\s+[0-9]+\\.opt$")); + if (files != null) { + for (java.io.File file : files) { + file.delete(); + } + } + } + } catch (Exception e) {} + var tester = new CockroachTester(); + ruleList().forEach(r -> tester.serialize(r, genPath)); + } + + // public static void runAllTests() { + // String packagePath = "src/main/java/org/qed/Backends/Cockroach/Tests"; + // java.io.File testDir = new java.io.File(packagePath); + // java.io.File[] testFiles = testDir.listFiles((dir, name) -> name.endsWith("Test.java")); + // if (testFiles != null) { + // for (java.io.File testFile : testFiles) { + // String className = "org.qed.Backends.Calcite.Tests." 
+ testFile.getName().replace(".java", ""); + // try { + // Class testClass = Class.forName(className); + // testClass.getMethod("runTest").invoke(null); + // } catch (Exception e) { + // throw new RuntimeException("Failed to run test: " + className, e); + // } + // } + // } + // } + + public static void main(String[] args) throws IOException { + // var rule = new org.qed.RRuleInstances.AggregateExtractProject(); + // System.out.println(rule.explain()); + // Files.createDirectories(Path.of(rulePath)); + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rules = new RRuleInstance.JoinAssociate(); + // Files.createDirectories(Path.of(rulePath)); + // for (var rule : rules.family()) { + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // } + generate(); + // runAllTests(); + } + + public void serialize(RRule rule, String path) { + var generator = new CockroachGenerator(); + var code_gen = generator.generate(rule); + try { + Files.write(Path.of(path, rule.name() + ".opt"), code_gen.getBytes()); + } catch (IOException ioe) { + System.err.println(ioe.getMessage()); + } + } + + // public void test(RelOptRule rule, Seq tests) { + // System.out.println("Testing rule " + rule.getClass().getSimpleName()); + // var runner = loadRule(rule); + // var exams = tests.mapUnchecked(t -> Tuple.of(t, JSONDeserializer.load(new File(t)))); + // for (var entry : exams) { + // if (entry.getValue().size() != 2) { + // System.err.println(entry.getKey() + " does not have exactly two nodes, and thus is not a valid test"); + // continue; + // } + // verify(runner, entry.getValue().get(0), entry.getValue().get(1)); + // } + // } + + // public void verify(HepPlanner runner, RelNode source, RelNode target) { + // runner.setRoot(source); + // var answer = runner.findBestExp(); + + // 
String answerExplain = answer.explain(); + // String targetExplain = target.explain(); + + // if(answerExplain.equals(targetExplain)) { + // if(answerExplain.equals(source.explain())) + // { + // System.out.println("trivial"); + // System.out.println("> Given source RelNode:\n" + source.explain()); + // System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + // System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + // } + // else + // { + // System.out.println("succeeded"); + // } + // return; + // } + // System.out.println("failed"); + // System.out.println("> Given source RelNode:\n" + source.explain()); + // System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + // System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + // } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachTests b/src/main/java/org/qed/Backends/Cockroach/CockroachTests new file mode 100644 index 0000000..3ead500 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachTests @@ -0,0 +1,1579 @@ +exec-ddl +CREATE TABLE sales (id INT PRIMARY KEY, category1 STRING, category2 STRING, amount DECIMAL) +---- + +exec-ddl +CREATE TABLE emp (empno INT PRIMARY KEY, ename STRING, job STRING, mgr INT, hiredate DATE, sal DECIMAL, comm DECIMAL, deptno INT) +---- + +exec-ddl +CREATE TABLE dept (deptno INT PRIMARY KEY, dname STRING, loc STRING) +---- + +# -------------------------------------------------- +# FilterMerge +# -------------------------------------------------- +norm expect=FilterMerge disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT * +FROM ( + SELECT * + FROM sales + WHERE category1 = 'hardware' +) AS sales_filtered +WHERE category2 = 'home' +---- +select + ├── columns: id:1!null category1:2!null category2:3!null amount:4 + ├── key: (1) + ├── fd: ()-->(2,3), (1)-->(4) + ├── scan sales + │ ├── columns: id:1!null category1:2 
category2:3 amount:4 + │ ├── key: (1) + │ └── fd: (1)-->(2-4) + └── filters + ├── category1:2 = 'hardware' [outer=(2), constraints=(/2: [/'hardware' - /'hardware']; tight), fd=()-->(2)] + └── category2:3 = 'home' [outer=(3), constraints=(/3: [/'home' - /'home']; tight), fd=()-->(3)] + +# -------------------------------------------------- +# FilterReduceTrue +# -------------------------------------------------- +norm expect=FilterReduceTrue disable=(ProjectFilterTranspose,JoinConditionPush,EliminateSelect,FilterProjectTranspose,JoinExtractFilter) +SELECT * +FROM sales +WHERE TRUE +---- +scan sales + ├── columns: id:1!null category1:2 category2:3 amount:4 + ├── key: (1) + └── fd: (1)-->(2-4) + +# -------------------------------------------------- +# FilterReduceFalse +# -------------------------------------------------- +norm expect=FilterReduceFalse disable=(ProjectFilterTranspose,JoinConditionPush,SimplifySelectFilters,FilterProjectTranspose,JoinExtractFilter) +SELECT * +FROM sales +WHERE FALSE +---- +values + ├── columns: id:1!null category1:2!null category2:3!null amount:4!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(1-4) + +# -------------------------------------------------- +# FilterIntoJoin +# -------------------------------------------------- +norm expect=FilterIntoJoin disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,PushSelectCondLeftIntoJoinLeftAndRight,PushSelectIntoJoinLeft,JoinConditionPush,FilterProjectTranspose,MergeSelectInnerJoin,JoinExtractFilter) +SELECT * +FROM emp +INNER JOIN dept ON emp.deptno = dept.deptno +WHERE emp.ename = dept.loc +---- +inner-join (hash) + ├── columns: empno:1!null ename:2!null job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null dname:12 loc:13!null + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8), (2)==(13), (13)==(2) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 
hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(2-8) + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + ├── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + └── ename:2 = loc:13 [outer=(2,13), constraints=(/2: (/NULL - ]; /13: (/NULL - ]), fd=(2)==(13), (13)==(2)] + +# -------------------------------------------------- +# PruneEmptyFilter +# -------------------------------------------------- +norm expect=PruneEmptyFilter disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,SimplifySelectFilters,FilterReduceTrue,SimplifyZeroCardinalityGroup,SimplifyJoinFilters,EliminateProject,PruneScanCols,PruneJoinRightCols,JoinPushTransitivePredicates,JoinConditionPush,FilterProjectTranspose,FilterIntoJoin,JoinExtractFilter) +SELECT * +FROM (SELECT * FROM emp LIMIT 0) AS e +CROSS JOIN dept +WHERE TRUE +---- +values + ├── columns: empno:1!null ename:2!null job:3!null mgr:4!null hiredate:5!null sal:6!null comm:7!null deptno:8!null deptno:11!null dname:12!null loc:13!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(1-8,11-13) + +norm expect-not=PruneEmptyFilter disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,SimplifySelectFilters,FilterReduceTrue,SimplifyZeroCardinalityGroup,SimplifyJoinFilters,EliminateProject,PruneScanCols,PruneJoinRightCols,JoinPushTransitivePredicates,JoinConditionPush,FilterProjectTranspose,FilterIntoJoin,JoinExtractFilter) +SELECT * +FROM emp +WHERE deptno > 10 +---- +project + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null + ├── key: (1) + ├── fd: (1)-->(2-8) + └── select + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null crdb_internal_mvcc_timestamp:9 tableoid:10 + ├── key: (1) + ├── fd: (1)-->(2-10) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 
deptno:8 crdb_internal_mvcc_timestamp:9 tableoid:10 + │ ├── key: (1) + │ └── fd: (1)-->(2-10) + └── filters + └── deptno:8 > 10 [outer=(8), constraints=(/8: [/11 - ]; tight)] + +# -------------------------------------------------- +# PruneEmptyProject +# -------------------------------------------------- +norm expect=PruneEmptyProject disable=(ProjectFilterTranspose,JoinConditionPush,EliminateProject,FilterProjectTranspose,JoinExtractFilter) +SELECT e.empno +FROM (SELECT empno FROM emp LIMIT 0) AS e +---- +values + ├── columns: empno:1!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(1) + +norm expect=PruneEmptyProject disable=(ProjectFilterTranspose,JoinConditionPush,EliminateProject,FilterProjectTranspose,JoinExtractFilter) +SELECT empno + 1 AS next_id FROM (SELECT empno FROM emp WHERE FALSE) +---- +values + ├── columns: next_id:11!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(11) + +norm expect-not=PruneEmptyProject disable=(ProjectFilterTranspose,JoinConditionPush,EliminateProject,FilterProjectTranspose,JoinExtractFilter) +SELECT empno + 1 AS next_id FROM emp +---- +project + ├── columns: next_id:11!null + ├── immutable + ├── scan emp + │ ├── columns: empno:1!null + │ └── key: (1) + └── projections + └── empno:1 + 1 [as=next_id:11, outer=(1), immutable] + +# -------------------------------------------------- +# ProjectMerge +# -------------------------------------------------- +norm expect=ProjectMerge disable=(ProjectFilterTranspose,EliminateProject,PruneScanCols,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT id, name +FROM ( + SELECT empno AS id, ename AS name, empno + 1 AS next_id + FROM emp +) AS t +---- +project + ├── columns: id:1!null name:2 + ├── key: (1) + ├── fd: (1)-->(2) + └── scan emp + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8 crdb_internal_mvcc_timestamp:9 tableoid:10 + ├── key: (1) + └── fd: (1)-->(2-10) + +norm expect-not=ProjectMerge 
disable=(ProjectFilterTranspose,EliminateProject,PruneScanCols,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT next_id * 2 AS doubled +FROM (SELECT empno + 1 AS next_id FROM emp) +---- +project + ├── columns: doubled:12!null + ├── immutable + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8 crdb_internal_mvcc_timestamp:9 tableoid:10 + │ ├── key: (1) + │ └── fd: (1)-->(2-10) + └── projections + └── (empno:1 + 1) * 2 [as=doubled:12, outer=(1), immutable] + +# -------------------------------------------------- +# PruneEmptyMinus +# -------------------------------------------------- +norm expect=PruneEmptyMinus disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno +FROM (SELECT empno FROM emp WHERE FALSE) AS empties +EXCEPT +SELECT deptno FROM dept +---- +values + ├── columns: empno:1!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(1) + +norm expect-not=PruneEmptyMinus disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +EXCEPT +SELECT empno FROM emp WHERE FALSE +---- +scan emp + ├── columns: empno:1!null + └── key: (1) + +# -------------------------------------------------- +# JoinReduceTrue +# -------------------------------------------------- +norm expect=JoinReduceTrue disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,SimplifyJoinFilters,SimplifyAndTrue,SimplifyTrueAnd,SimplifySelectFilters,JoinConditionPush,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,JoinExtractFilter) +SELECT * +FROM emp +INNER JOIN ( + SELECT deptno FROM dept WHERE TRUE +) AS d ON TRUE +WHERE emp.deptno = d.deptno +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: 
(1)-->(2-8), (8)==(11), (11)==(8) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(2-8) + ├── select + │ ├── columns: dept.deptno:11!null + │ ├── key: (11) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null + │ │ └── key: (11) + │ └── filters + │ └── true + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +norm expect-not=JoinReduceTrue disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,SimplifyJoinFilters,SimplifyAndTrue,SimplifyTrueAnd,SimplifySelectFilters,JoinConditionPush,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,JoinExtractFilter) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(2-8) + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +# -------------------------------------------------- +# FilterAggregateTranspose +# -------------------------------------------------- +norm expect=FilterAggregateTranspose disable=(ProjectFilterTranspose,PushSelectIntoGroupBy,AggregateFilterTranspose,EliminateProject,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT job, cnt +FROM ( + SELECT job, COUNT(*) as cnt + FROM emp + GROUP BY job +) AS grouped +WHERE job = 'Manager' +---- 
+group-by (streaming) + ├── columns: job:3!null cnt:11!null + ├── cardinality: [0 - 1] + ├── key: () + ├── fd: ()-->(3,11) + ├── select + │ ├── columns: job:3!null + │ ├── fd: ()-->(3) + │ ├── scan emp + │ │ └── columns: job:3 + │ └── filters + │ └── job:3 = 'Manager' [outer=(3), constraints=(/3: [/'Manager' - /'Manager']; tight), fd=()-->(3)] + └── aggregations + ├── count-rows [as=count_rows:11] + └── const-agg [as=job:3, outer=(3)] + └── job:3 + +norm expect-not=FilterAggregateTranspose disable=(ProjectFilterTranspose,PushSelectIntoGroupBy,AggregateFilterTranspose,EliminateProject,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT job, cnt +FROM ( + SELECT job, COUNT(*) as cnt + FROM emp + GROUP BY job +) AS grouped +WHERE cnt > 5 +---- +select + ├── columns: job:3 cnt:11!null + ├── key: (3) + ├── fd: (3)-->(11) + ├── group-by (hash) + │ ├── columns: job:3 count_rows:11!null + │ ├── grouping columns: job:3 + │ ├── key: (3) + │ ├── fd: (3)-->(11) + │ ├── scan emp + │ │ └── columns: job:3 + │ └── aggregations + │ └── count-rows [as=count_rows:11] + └── filters + └── count_rows:11 > 5 [outer=(11), constraints=(/11: [/6 - ]; tight)] + +# -------------------------------------------------- +# ProjectAggregateMerge +# -------------------------------------------------- +norm expect=ProjectAggregateMerge disable=(ProjectFilterTranspose,PruneAggCols,JoinConditionPush,FilterProjectTranspose,EliminateProject,AggregateProjectMerge,JoinExtractFilter) +SELECT job, sumsal +FROM ( + SELECT job, SUM(sal) AS sumsal, MIN(ename||'foo') AS minfoo + FROM emp + GROUP BY job +) AS grouped +---- +project + ├── columns: job:3 sumsal:11 + ├── key: (3) + ├── fd: (3)-->(11) + └── group-by (hash) + ├── columns: job:3 sum:11 + ├── grouping columns: job:3 + ├── key: (3) + ├── fd: (3)-->(11) + ├── scan emp + │ └── columns: job:3 sal:6 + └── aggregations + └── sum [as=sum:11, outer=(6)] + └── sal:6 + +norm expect-not=ProjectAggregateMerge 
disable=(ProjectFilterTranspose,PruneAggCols,JoinConditionPush,FilterProjectTranspose,EliminateProject,AggregateProjectMerge,JoinExtractFilter) +SELECT job, sumsal, minfoo +FROM ( + SELECT job, SUM(sal) AS sumsal, MIN(ename) AS minfoo + FROM emp + GROUP BY job +) AS grouped +---- +group-by (hash) + ├── columns: job:3 sumsal:11 minfoo:12 + ├── grouping columns: job:3 + ├── key: (3) + ├── fd: (3)-->(11,12) + ├── scan emp + │ └── columns: ename:2 job:3 sal:6 + └── aggregations + ├── sum [as=sum:11, outer=(6)] + │ └── sal:6 + └── min [as=min:12, outer=(2)] + └── ename:2 + +# -------------------------------------------------- +# AggregateFilterTranspose +# -------------------------------------------------- +norm expect=AggregateFilterTranspose disable=(ProjectFilterTranspose,PushSelectIntoGroupBy,EliminateProject,FilterAggregateTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT job, COUNT(*) as cnt +FROM ( + SELECT job, deptno + FROM emp + WHERE deptno > 10 +) AS filtered +GROUP BY job, deptno +---- +project + ├── columns: job:3 cnt:11!null + └── select + ├── columns: job:3 deptno:8!null count_rows:11!null + ├── key: (3,8) + ├── fd: (3,8)-->(11) + ├── group-by (hash) + │ ├── columns: job:3 deptno:8 count_rows:11!null + │ ├── grouping columns: job:3 deptno:8 + │ ├── key: (3,8) + │ ├── fd: (3,8)-->(11) + │ ├── scan emp + │ │ └── columns: job:3 deptno:8 + │ └── aggregations + │ └── count-rows [as=count_rows:11] + └── filters + └── deptno:8 > 10 [outer=(8), constraints=(/8: [/11 - ]; tight)] + +norm expect-not=AggregateFilterTranspose disable=(ProjectFilterTranspose,PushSelectIntoGroupBy,EliminateProject,FilterAggregateTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT job, COUNT(*) as cnt +FROM ( + SELECT job, sal + FROM emp + WHERE sal > 5000 +) AS filtered +GROUP BY job +---- +group-by (hash) + ├── columns: job:3 cnt:11!null + ├── grouping columns: job:3 + ├── immutable + ├── key: (3) + ├── fd: (3)-->(11) + ├── select 
+ │ ├── columns: job:3 sal:6!null + │ ├── immutable + │ ├── scan emp + │ │ └── columns: job:3 sal:6 + │ └── filters + │ └── sal:6 > 5000 [outer=(6), immutable, constraints=(/6: (/5000 - ]; tight)] + └── aggregations + └── count-rows [as=count_rows:11] + +# -------------------------------------------------- +# SemiJoinFilterTranspose +# -------------------------------------------------- +norm expect=SemiJoinFilterTranspose disable=(ProjectFilterTranspose,PushSelectIntoJoinLeft,EliminateSelect,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT * +FROM emp +WHERE EXISTS (SELECT * FROM dept WHERE emp.deptno = dept.deptno) + AND emp.sal > 1000 +---- +semi-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8 + ├── immutable + ├── key: (1) + ├── fd: (1)-->(2-8) + ├── select + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 emp.deptno:8 + │ ├── immutable + │ ├── key: (1) + │ ├── fd: (1)-->(2-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ └── filters + │ └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + ├── scan dept + │ ├── columns: dept.deptno:11!null + │ └── key: (11) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +norm expect-not=SemiJoinFilterTranspose disable=(ProjectFilterTranspose,PushSelectIntoJoinLeft,EliminateSelect,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter,JoinPushTransitivePredicates,FilterIntoJoin,MergeSelectInnerJoin,PushSelectCondLeftIntoJoinLeftAndRight,JoinAddRedundantSemiJoin) +SELECT * +FROM emp +INNER JOIN dept ON emp.deptno = dept.deptno +WHERE emp.sal > 1000 +---- +select + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── immutable + ├── key: 
(1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── inner-join (hash) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8!null dept.deptno:11!null dname:12 loc:13 + │ ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + │ ├── key: (1) + │ ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ │ ├── key: (11) + │ │ └── fd: (11)-->(12,13) + │ └── filters + │ └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + └── filters + └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + +# -------------------------------------------------- +# JoinPushTransitivePredicates +# -------------------------------------------------- +norm expect=JoinPushTransitivePredicates disable=(ProjectFilterTranspose,PushFilterIntoJoinLeftAndRight,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,MapFilterIntoJoinLeft,MapFilterIntoJoinRight,JoinConditionPush,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,FilterIntoJoin,JoinExtractFilter) +SELECT * +FROM emp +INNER JOIN dept ON emp.deptno = dept.deptno +WHERE emp.sal > 1000 +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── immutable + ├── key: (1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── semi-join (hash) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ ├── fd: (1)-->(2-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 
sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null + │ │ └── key: (11) + │ └── filters + │ └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + ├── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + +norm expect-not=JoinPushTransitivePredicates disable=(ProjectFilterTranspose,PushFilterIntoJoinLeftAndRight,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,MapFilterIntoJoinLeft,MapFilterIntoJoinRight,JoinConditionPush,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,JoinExtractFilter) +SELECT * FROM emp INNER JOIN dept ON TRUE WHERE emp.sal > 1000 +---- +inner-join (cross) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8 deptno:11!null dname:12 loc:13 + ├── immutable + ├── key: (1,11) + ├── fd: (1)-->(2-8), (11)-->(12,13) + ├── semi-join (cross) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ ├── fd: (1)-->(2-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ ├── scan dept + │ └── filters (true) + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + +# -------------------------------------------------- +# ProjectFilterTranspose +# -------------------------------------------------- +norm expect=ProjectFilterTranspose 
disable=(FilterProjectTranspose,EliminateProject,JoinConditionPush,PushSelectIntoProject,PushSelectIntoInlinableProject,JoinExtractFilter) +SELECT empno AS e, deptno AS d +FROM emp +WHERE deptno > 10 +---- +select + ├── columns: e:1!null d:8!null + ├── key: (1) + ├── fd: (1)-->(8) + ├── project + │ ├── columns: empno:1!null deptno:8 + │ ├── key: (1) + │ ├── fd: (1)-->(8) + │ └── scan emp + │ ├── columns: empno:1!null deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(8) + └── filters + └── deptno:8 > 10 [outer=(8), constraints=(/8: [/11 - ]; tight)] + +norm expect-not=ProjectFilterTranspose disable=(FilterProjectTranspose,EliminateProject,JoinConditionPush,PushSelectIntoProject,PushSelectIntoInlinableProject,JoinExtractFilter) +SELECT empno +FROM emp +WHERE sal > 1000 +---- +project + ├── columns: empno:1!null + ├── immutable + ├── key: (1) + └── select + ├── columns: empno:1!null sal:6!null + ├── immutable + ├── key: (1) + ├── fd: (1)-->(6) + ├── scan emp + │ ├── columns: empno:1!null sal:6 + │ ├── key: (1) + │ └── fd: (1)-->(6) + └── filters + └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + +# -------------------------------------------------- +# FilterProjectTranspose +# -------------------------------------------------- +norm expect=FilterProjectTranspose disable=(ProjectFilterTranspose,PushSelectIntoProject,EliminateProject,JoinConditionPush,JoinExtractFilter) +SELECT id, name +FROM ( + SELECT empno AS id, ename AS name + FROM emp +) AS projected +WHERE id > 100 +---- +project + ├── columns: id:1!null name:2 + ├── key: (1) + ├── fd: (1)-->(2) + └── select + ├── columns: empno:1!null ename:2 + ├── key: (1) + ├── fd: (1)-->(2) + ├── scan emp + │ ├── columns: empno:1!null ename:2 + │ ├── key: (1) + │ └── fd: (1)-->(2) + └── filters + └── empno:1 > 100 [outer=(1), constraints=(/1: [/101 - ]; tight)] + +norm expect-not=FilterProjectTranspose 
disable=(ProjectFilterTranspose,PushSelectIntoProject,PushSelectIntoInlinableProject,EliminateProject,JoinConditionPush,JoinExtractFilter) +SELECT id, id_plus_one +FROM ( + SELECT empno AS id, empno + 1 AS id_plus_one + FROM emp +) AS projected +WHERE id_plus_one > 100 +---- +select + ├── columns: id:1!null id_plus_one:11!null + ├── immutable + ├── key: (1) + ├── fd: (1)-->(11) + ├── project + │ ├── columns: id_plus_one:11!null empno:1!null + │ ├── immutable + │ ├── key: (1) + │ ├── fd: (1)-->(11) + │ ├── scan emp + │ │ ├── columns: empno:1!null + │ │ └── key: (1) + │ └── projections + │ └── empno:1 + 1 [as=id_plus_one:11, outer=(1), immutable] + └── filters + └── id_plus_one:11 > 100 [outer=(11), constraints=(/11: [/101 - ]; tight)] + +# -------------------------------------------------- +# PruneEmptyIntersect +# -------------------------------------------------- +norm expect=PruneEmptyIntersect disable=(ProjectFilterTranspose,SimplifyZeroCardinalityGroup,PruneEmptyFilter,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno +FROM emp +INTERSECT +SELECT empno FROM emp WHERE FALSE +---- +values + ├── columns: empno:1!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(1) + +norm expect-not=PruneEmptyIntersect disable=(ProjectFilterTranspose,SimplifyZeroCardinalityGroup,PruneEmptyFilter,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +INTERSECT +SELECT empno FROM emp +---- +intersect-all + ├── columns: empno:1!null + ├── left columns: empno:1!null + ├── right columns: empno:11 + ├── key: (1) + ├── scan emp + │ ├── columns: empno:1!null + │ └── key: (1) + └── scan emp + ├── columns: empno:11!null + └── key: (11) + +# -------------------------------------------------- +# JoinCommute +# -------------------------------------------------- +norm expect=JoinCommute 
disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,EliminateProject,PruneScanCols,ProjectMerge,JoinConditionPush,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,JoinExtractFilter) +SELECT * +FROM dept +INNER JOIN emp ON dept.deptno = emp.deptno +---- +project + ├── columns: deptno:1!null dname:2 loc:3 empno:6!null ename:7 job:8 mgr:9 hiredate:10 sal:11 comm:12 deptno:13!null + ├── key: (6) + ├── fd: (6)-->(1-3,7-13), (1)-->(2,3), (1)==(13), (13)==(1) + └── inner-join (hash) + ├── columns: dept.deptno:1!null dept.dname:2 dept.loc:3 dept.crdb_internal_mvcc_timestamp:4 dept.tableoid:5 emp.empno:6!null emp.ename:7 emp.job:8 emp.mgr:9 emp.hiredate:10 emp.sal:11 emp.comm:12 emp.deptno:13!null emp.crdb_internal_mvcc_timestamp:14 emp.tableoid:15 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (6) + ├── fd: (6)-->(7-15), (1)-->(2-5), (1)==(13), (13)==(1) + ├── scan emp + │ ├── columns: emp.empno:6!null emp.ename:7 emp.job:8 emp.mgr:9 emp.hiredate:10 emp.sal:11 emp.comm:12 emp.deptno:13 emp.crdb_internal_mvcc_timestamp:14 emp.tableoid:15 + │ ├── key: (6) + │ └── fd: (6)-->(7-15) + ├── scan dept + │ ├── columns: dept.deptno:1!null dept.dname:2 dept.loc:3 dept.crdb_internal_mvcc_timestamp:4 dept.tableoid:5 + │ ├── key: (1) + │ └── fd: (1)-->(2-5) + └── filters + └── dept.deptno:1 = emp.deptno:13 [outer=(1,13), constraints=(/1: (/NULL - ]; /13: (/NULL - ]), fd=(1)==(13), (13)==(1)] + +# -------------------------------------------------- +# JoinReduceFalse +# -------------------------------------------------- +norm expect=JoinReduceFalse disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,SimplifyJoinFilters,SimplifyAndFalse,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno AND FALSE +---- +values + ├── columns: empno:1!null ename:2!null job:3!null mgr:4!null hiredate:5!null sal:6!null comm:7!null 
deptno:8!null deptno:11!null dname:12!null loc:13!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(1-8,11-13) + +norm expect-not=JoinReduceFalse disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,SimplifyJoinFilters,SimplifyAndFalse,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(2-8) + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +# -------------------------------------------------- +# UnionPullUpConstants +# -------------------------------------------------- +norm expect=UnionPullUpConstants disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT empno, 'ACTIVE' AS status, ename +FROM ( + SELECT empno, 'ACTIVE' AS status, ename FROM emp + UNION ALL + SELECT empno, 'ACTIVE' AS status, ename FROM emp +) AS unioned +---- +project + ├── columns: empno:23!null status:26!null ename:25 + ├── fd: ()-->(26) + ├── union-all + │ ├── columns: empno:23!null ename:25 + │ ├── left columns: emp.empno:1 emp.ename:2 + │ ├── right columns: emp.empno:12 emp.ename:13 + │ ├── scan emp + │ │ ├── columns: emp.empno:1!null emp.ename:2 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2) + │ └── scan emp + │ ├── columns: emp.empno:12!null emp.ename:13 + │ ├── key: (12) + │ └── fd: (12)-->(13) + └── projections + └── 'ACTIVE' 
[as=status:26] + +norm expect-not=UnionPullUpConstants disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT empno, 'ACTIVE' AS status, ename FROM emp +UNION +SELECT empno, 'ACTIVE' AS status, ename FROM emp +---- +union + ├── columns: empno:23!null status:24!null ename:25 + ├── left columns: emp.empno:1 status:11 emp.ename:2 + ├── right columns: emp.empno:12 status:22 emp.ename:13 + ├── key: (23-25) + ├── project + │ ├── columns: status:11!null emp.empno:1!null emp.ename:2 + │ ├── key: (1) + │ ├── fd: ()-->(11), (1)-->(2) + │ ├── scan emp + │ │ ├── columns: emp.empno:1!null emp.ename:2 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2) + │ └── projections + │ └── 'ACTIVE' [as=status:11] + └── project + ├── columns: status:22!null emp.empno:12!null emp.ename:13 + ├── key: (12) + ├── fd: ()-->(22), (12)-->(13) + ├── scan emp + │ ├── columns: emp.empno:12!null emp.ename:13 + │ ├── key: (12) + │ └── fd: (12)-->(13) + └── projections + └── 'ACTIVE' [as=status:22] + +# Matching constants at union column positions are required on both sides. 
+norm expect-not=UnionPullUpConstants disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT empno, 'A' AS status, ename FROM emp +UNION ALL +SELECT empno, 'B' AS status, ename FROM emp +---- +union-all + ├── columns: empno:23!null status:24!null ename:25 + ├── left columns: emp.empno:1 status:11 emp.ename:2 + ├── right columns: emp.empno:12 status:22 emp.ename:13 + ├── project + │ ├── columns: status:11!null emp.empno:1!null emp.ename:2 + │ ├── key: (1) + │ ├── fd: ()-->(11), (1)-->(2) + │ ├── scan emp + │ │ ├── columns: emp.empno:1!null emp.ename:2 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2) + │ └── projections + │ └── 'A' [as=status:11] + └── project + ├── columns: status:22!null emp.empno:12!null emp.ename:13 + ├── key: (12) + ├── fd: ()-->(22), (12)-->(13) + ├── scan emp + │ ├── columns: emp.empno:12!null emp.ename:13 + │ ├── key: (12) + │ └── fd: (12)-->(13) + └── projections + └── 'B' [as=status:22] + + +# -------------------------------------------------- +# PruneEmptyUnion +# -------------------------------------------------- +norm expect=PruneEmptyUnion disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT * FROM (VALUES (1)) AS t1(x) WHERE FALSE +UNION +SELECT * FROM (VALUES (1)) AS t2(x) WHERE FALSE +---- +values + ├── columns: x:3!null + ├── cardinality: [0 - 0] + ├── key: () + └── fd: ()-->(3) + +norm expect-not=PruneEmptyUnion disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT * FROM (VALUES (1)) AS t1(x) WHERE FALSE +UNION +SELECT * FROM (VALUES (2)) AS t2(x) +---- +values + ├── columns: x:3!null + ├── cardinality: [1 - 1] + ├── key: () + ├── fd: ()-->(3) + └── (2,) + +# -------------------------------------------------- +# JoinConditionPush +# -------------------------------------------------- +norm expect=JoinConditionPush 
disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,JoinExtractFilter) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno AND emp.sal > 1000 AND dept.loc = 'NEW YORK' +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8!null deptno:11!null dname:12 loc:13!null + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── immutable + ├── key: (1) + ├── fd: ()-->(13), (1)-->(2-8), (11)-->(12), (8)==(11), (11)==(8) + ├── select + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 emp.deptno:8 + │ ├── immutable + │ ├── key: (1) + │ ├── fd: (1)-->(2-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ └── filters + │ └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + ├── select + │ ├── columns: dept.deptno:11!null dname:12 loc:13!null + │ ├── key: (11) + │ ├── fd: ()-->(13), (11)-->(12) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ │ ├── key: (11) + │ │ └── fd: (11)-->(12,13) + │ └── filters + │ └── loc:13 = 'NEW YORK' [outer=(13), constraints=(/13: [/'NEW YORK' - /'NEW YORK']; tight), fd=()-->(13)] + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +norm expect-not=JoinConditionPush disable=(JoinAddRedundantSemiJoin,ProjectFilterTranspose,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,FilterProjectTranspose,MergeSelectInnerJoin,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,JoinExtractFilter) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 
comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(2-8) + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +# -------------------------------------------------- +# IntersectMerge +# -------------------------------------------------- +norm expect=IntersectMerge disable=(ProjectFilterTranspose,SimplifyIntersectLeft,SimplifyIntersectRight,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +INTERSECT +SELECT empno FROM emp +INTERSECT +SELECT deptno FROM dept +---- +intersect + ├── columns: empno:1!null + ├── left columns: empno:1!null + ├── right columns: empno:11 + ├── key: (1) + ├── scan emp + │ ├── columns: empno:1!null + │ └── key: (1) + └── intersect + ├── columns: empno:11!null + ├── left columns: empno:11!null + ├── right columns: dept.deptno:21 + ├── key: (11) + ├── scan emp + │ ├── columns: empno:11!null + │ └── key: (11) + └── scan dept + ├── columns: dept.deptno:21!null + └── key: (21) + +# -------------------------------------------------- +# UnionMerge +# -------------------------------------------------- +norm expect=UnionMerge disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +UNION +SELECT empno FROM emp +UNION +SELECT deptno FROM dept +---- +union + ├── columns: empno:27!null + ├── left columns: emp.empno:1 + ├── right columns: empno:21 + ├── key: (27) + ├── scan emp + │ ├── columns: emp.empno:1!null + │ └── key: (1) + └── union + ├── columns: empno:21!null + ├── left columns: 
emp.empno:11 + ├── right columns: dept.deptno:22 + ├── key: (21) + ├── scan emp + │ ├── columns: emp.empno:11!null + │ └── key: (11) + └── scan dept + ├── columns: dept.deptno:22!null + └── key: (22) + +norm expect-not=UnionMerge disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +UNION +( + SELECT empno FROM emp + UNION + SELECT deptno FROM dept +) +---- +union + ├── columns: empno:27!null + ├── left columns: emp.empno:1 + ├── right columns: empno:26 + ├── key: (27) + ├── scan emp + │ ├── columns: emp.empno:1!null + │ └── key: (1) + └── union + ├── columns: empno:26!null + ├── left columns: emp.empno:11 + ├── right columns: dept.deptno:21 + ├── key: (26) + ├── scan emp + │ ├── columns: emp.empno:11!null + │ └── key: (11) + └── scan dept + ├── columns: dept.deptno:21!null + └── key: (21) + +# -------------------------------------------------- +# FilterSetOpTranspose +# -------------------------------------------------- +norm expect=FilterSetOpTranspose disable=(ProjectFilterTranspose,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT * FROM (SELECT empno FROM emp UNION SELECT deptno FROM dept) WHERE empno > 10 +---- +union + ├── columns: empno:16!null + ├── left columns: emp.empno:1 + ├── right columns: dept.deptno:11 + ├── key: (16) + ├── project + │ ├── columns: emp.empno:1!null + │ ├── key: (1) + │ └── select + │ ├── columns: emp.empno:1!null + │ ├── key: (1) + │ ├── scan emp + │ │ ├── columns: emp.empno:1!null + │ │ └── key: (1) + │ └── filters + │ └── emp.empno:1 > 10 [outer=(1), constraints=(/1: [/11 - ]; tight)] + └── project + ├── columns: dept.deptno:11!null + ├── key: (11) + └── select + ├── columns: dept.deptno:11!null + ├── key: (11) + ├── scan dept + │ ├── columns: dept.deptno:11!null + │ └── key: (11) + └── filters + └── dept.deptno:11 > 10 [outer=(11), constraints=(/11: [/11 - ]; tight)] + +# -------------------------------------------------- +# UnionToDistinct +# 
-------------------------------------------------- +norm expect=UnionToDistinct disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT empno, ename FROM emp +UNION +SELECT empno, ename FROM emp +---- +distinct-on + ├── columns: empno:21!null ename:22 + ├── grouping columns: empno:21!null + ├── key: (21) + ├── fd: (21)-->(22) + ├── union-all + │ ├── columns: empno:21!null ename:22 + │ ├── left columns: emp.empno:1 emp.ename:2 + │ ├── right columns: emp.empno:11 emp.ename:12 + │ ├── project + │ │ ├── columns: emp.empno:1!null emp.ename:2 + │ │ ├── key: (1) + │ │ ├── fd: (1)-->(2) + │ │ └── scan emp + │ │ ├── columns: emp.empno:1!null emp.ename:2 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2) + │ └── project + │ ├── columns: emp.empno:11!null emp.ename:12 + │ ├── key: (11) + │ ├── fd: (11)-->(12) + │ └── scan emp + │ ├── columns: emp.empno:11!null emp.ename:12 + │ ├── key: (11) + │ └── fd: (11)-->(12) + └── aggregations + └── const-agg [as=ename:22, outer=(22)] + └── ename:22 + +norm expect-not=(UnionToDistinct,ConvertUnionToDistinctUnionAll) disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT empno FROM emp +UNION +SELECT empno FROM emp +---- +union + ├── columns: empno:21!null + ├── left columns: emp.empno:1 + ├── right columns: emp.empno:11 + ├── key: (21) + ├── project + │ ├── columns: emp.empno:1!null + │ ├── key: (1) + │ └── scan emp + │ ├── columns: emp.empno:1!null + │ └── key: (1) + └── project + ├── columns: emp.empno:11!null + ├── key: (11) + └── scan emp + ├── columns: emp.empno:11!null + └── key: (11) + +# Columns must map to the same base table on both sides. 
+norm expect-not=(UnionToDistinct,ConvertUnionToDistinctUnionAll) disable=(ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,EliminateProject,JoinExtractFilter) +SELECT empno FROM emp +UNION +SELECT deptno FROM dept +---- +union + ├── columns: empno:16!null + ├── left columns: emp.empno:1 + ├── right columns: dept.deptno:11 + ├── key: (16) + ├── project + │ ├── columns: emp.empno:1!null + │ ├── key: (1) + │ └── scan emp + │ ├── columns: emp.empno:1!null + │ └── key: (1) + └── project + ├── columns: dept.deptno:11!null + ├── key: (11) + └── scan dept + ├── columns: dept.deptno:11!null + └── key: (11) + +# -------------------------------------------------- +# AggregateJoinJoinRemove +# -------------------------------------------------- +norm expect=AggregateJoinJoinRemove disable=(ProjectFilterTranspose,EliminateJoinUnderGroupByLeft,EliminateJoinUnderGroupByRight,PruneScanCols,JoinConditionPush,FilterProjectTranspose,EliminateDistinct,EliminateGroupBy,EliminateProject,JoinExtractFilter) +SELECT e.empno, d.deptno +FROM emp e +LEFT JOIN dept d2 ON e.deptno = d2.deptno +LEFT JOIN dept d ON e.deptno = d.deptno +GROUP BY e.empno, d.deptno +---- +distinct-on + ├── columns: empno:1!null deptno:16 + ├── grouping columns: empno:1!null + ├── key: (1) + ├── fd: (1)-->(16) + ├── left-join (hash) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 d.deptno:16 d.dname:17 d.loc:18 d.crdb_internal_mvcc_timestamp:19 d.tableoid:20 + │ ├── multiplicity: left-rows(exactly-one), right-rows(zero-or-more) + │ ├── key: (1) + │ ├── fd: (1)-->(2-10,16-20), (16)-->(17-20) + │ ├── scan emp [as=e] + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-10) + │ ├── scan dept [as=d] + │ │ ├── columns: d.deptno:16!null d.dname:17 d.loc:18 d.crdb_internal_mvcc_timestamp:19 d.tableoid:20 + 
│ │ ├── key: (16) + │ │ └── fd: (16)-->(17-20) + │ └── filters + │ └── e.deptno:8 = d.deptno:16 [outer=(8,16), constraints=(/8: (/NULL - ]; /16: (/NULL - ]), fd=(8)==(16), (16)==(8)] + └── aggregations + └── const-agg [as=d.deptno:16, outer=(16)] + └── d.deptno:16 + +norm expect-not=AggregateJoinJoinRemove disable=(ProjectFilterTranspose,EliminateJoinUnderGroupByLeft,EliminateJoinUnderGroupByRight,PruneScanCols,JoinConditionPush,FilterProjectTranspose,EliminateDistinct,EliminateGroupBy,EliminateProject,JoinExtractFilter) +SELECT e.empno, d.deptno +FROM emp e +LEFT JOIN dept d2 ON e.empno = d2.deptno +LEFT JOIN dept d ON d2.loc = d.loc +GROUP BY e.empno, d.deptno +---- +distinct-on + ├── columns: empno:1!null deptno:16 + ├── grouping columns: empno:1!null d.deptno:16 + ├── key: (1,16) + └── left-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 d2.deptno:11 d2.dname:12 d2.loc:13 d2.crdb_internal_mvcc_timestamp:14 d2.tableoid:15 d.deptno:16 d.dname:17 d.loc:18 d.crdb_internal_mvcc_timestamp:19 d.tableoid:20 + ├── key: (1,16) + ├── fd: (1)-->(2-15), (11)-->(12-15), (16)-->(17-20) + ├── left-join (hash) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 d2.deptno:11 d2.dname:12 d2.loc:13 d2.crdb_internal_mvcc_timestamp:14 d2.tableoid:15 + │ ├── multiplicity: left-rows(exactly-one), right-rows(zero-or-one) + │ ├── key: (1) + │ ├── fd: (1)-->(2-15), (11)-->(12-15) + │ ├── scan emp [as=e] + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-10) + │ ├── scan dept [as=d2] + │ │ ├── columns: d2.deptno:11!null d2.dname:12 d2.loc:13 d2.crdb_internal_mvcc_timestamp:14 d2.tableoid:15 + │ │ ├── key: (11) + │ │ └── fd: (11)-->(12-15) + │ └── filters + │ └── empno:1 = d2.deptno:11 
[outer=(1,11), constraints=(/1: (/NULL - ]; /11: (/NULL - ]), fd=(1)==(11), (11)==(1)] + ├── scan dept [as=d] + │ ├── columns: d.deptno:16!null d.dname:17 d.loc:18 d.crdb_internal_mvcc_timestamp:19 d.tableoid:20 + │ ├── key: (16) + │ └── fd: (16)-->(17-20) + └── filters + └── d2.loc:13 = d.loc:18 [outer=(13,18), constraints=(/13: (/NULL - ]; /18: (/NULL - ]), fd=(13)==(18), (18)==(13)] + +# -------------------------------------------------- +# AggregateJoinRemove +# -------------------------------------------------- +norm expect=AggregateJoinRemove disable=(ProjectFilterTranspose,EliminateJoinUnderGroupByLeft,EliminateJoinUnderGroupByRight,PruneScanCols,JoinConditionPush,FilterProjectTranspose,EliminateDistinct,EliminateGroupBy,EliminateProject,EliminateJoinUnderProjectLeft,JoinExtractFilter) +SELECT e.job +FROM emp e +LEFT JOIN dept d ON e.deptno = d.deptno +GROUP BY e.job +---- +distinct-on + ├── columns: job:3 + ├── grouping columns: job:3 + ├── key: (3) + └── scan emp [as=e] + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 + ├── key: (1) + └── fd: (1)-->(2-10) + +norm expect-not=AggregateJoinRemove disable=(ProjectFilterTranspose,EliminateJoinUnderGroupByLeft,EliminateJoinUnderGroupByRight,PruneScanCols,JoinConditionPush,FilterProjectTranspose,EliminateDistinct,EliminateGroupBy,EliminateProject,EliminateJoinUnderProjectLeft,JoinExtractFilter) +SELECT e.job, d.loc +FROM emp e +LEFT JOIN dept d ON e.deptno = d.deptno +GROUP BY e.job, d.loc +---- +distinct-on + ├── columns: job:3 loc:13 + ├── grouping columns: job:3 loc:13 + ├── key: (3,13) + └── left-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 d.deptno:11 dname:12 loc:13 d.crdb_internal_mvcc_timestamp:14 d.tableoid:15 + ├── multiplicity: left-rows(exactly-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: (1)-->(2-15), 
(11)-->(12-15) + ├── scan emp [as=e] + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 e.deptno:8 e.crdb_internal_mvcc_timestamp:9 e.tableoid:10 + │ ├── key: (1) + │ └── fd: (1)-->(2-10) + ├── scan dept [as=d] + │ ├── columns: d.deptno:11!null dname:12 loc:13 d.crdb_internal_mvcc_timestamp:14 d.tableoid:15 + │ ├── key: (11) + │ └── fd: (11)-->(12-15) + └── filters + └── e.deptno:8 = d.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +# -------------------------------------------------- +# AggregateProjectConstantToDummyJoin +# -------------------------------------------------- +norm expect=AggregateProjectConstantToDummyJoin disable=(JoinAddRedundantSemiJoin,MergeSelectInnerJoin,ProjectInnerJoinValues,InlineJoinConstantsLeft,InlineJoinConstantsRight,ProjectMerge,EliminateJoinNoColsRight,EliminateJoinNoColsLeft,PruneJoinRightCols,PruneJoinLeftCols,HoistJoinProjectLeft,HoistJoinProjectRight,JoinExtractFilter) +SELECT count(*) +FROM (SELECT empno, 1 AS k FROM emp) +GROUP BY k +---- +group-by (streaming) + ├── columns: count:12!null + ├── cardinality: [0 - 1] + ├── key: () + ├── fd: ()-->(12) + ├── scan emp + └── aggregations + └── count-rows [as=count_rows:12] + +norm expect-not=AggregateProjectConstantToDummyJoin disable=(JoinAddRedundantSemiJoin,MergeSelectInnerJoin,ProjectInnerJoinValues,InlineJoinConstantsLeft,InlineJoinConstantsRight,ProjectMerge,EliminateJoinNoColsRight,EliminateJoinNoColsLeft,PruneJoinRightCols,PruneJoinLeftCols,HoistJoinProjectLeft,HoistJoinProjectRight,JoinExtractFilter) +SELECT count(*) +FROM (SELECT empno, empno + 1 AS k FROM emp) +GROUP BY k +---- +project + ├── columns: count:12!null + ├── immutable + └── group-by (hash) + ├── columns: k:11!null count_rows:12!null + ├── grouping columns: k:11!null + ├── immutable + ├── key: (11) + ├── fd: (11)-->(12) + ├── project + │ ├── columns: k:11!null + │ ├── immutable + │ ├── scan emp + │ │ ├── columns: empno:1!null + │ │ └── 
key: (1) + │ └── projections + │ └── empno:1 + 1 [as=k:11, outer=(1), immutable] + └── aggregations + └── count-rows [as=count_rows:12] + +# -------------------------------------------------- +# AggregateProjectMerge +# -------------------------------------------------- +norm expect=AggregateProjectMerge disable=(AggregateExtractProject,MergeSelectInnerJoin,ProjectInnerJoinValues,InlineJoinConstantsLeft,InlineJoinConstantsRight,ProjectMerge,EliminateJoinNoColsRight,EliminateJoinNoColsLeft,PruneJoinRightCols,PruneJoinLeftCols,HoistJoinProjectLeft,HoistJoinProjectRight,JoinExtractFilter) +SELECT sum(k) +FROM (SELECT empno, empno + 1 AS k FROM emp) +GROUP BY empno +---- +project + ├── columns: sum:12!null + ├── immutable + └── group-by (hash) + ├── columns: empno:1!null sum:12!null + ├── grouping columns: empno:1!null + ├── immutable + ├── key: (1) + ├── fd: (1)-->(12) + ├── scan emp + │ ├── columns: empno:1!null + │ └── key: (1) + └── aggregations + └── sum [as=sum:12, outer=(1), immutable] + └── empno:1 + 1 + +norm expect-not=AggregateProjectMerge disable=(AggregateExtractProject,MergeSelectInnerJoin,ProjectInnerJoinValues,InlineJoinConstantsLeft,InlineJoinConstantsRight,ProjectMerge,EliminateJoinNoColsRight,EliminateJoinNoColsLeft,PruneJoinRightCols,PruneJoinLeftCols,HoistJoinProjectLeft,HoistJoinProjectRight,JoinExtractFilter) +SELECT sum(sal) +FROM (SELECT sal, empno + 1 AS k FROM emp) +GROUP BY k +---- +project + ├── columns: sum:12 + ├── immutable + └── group-by (hash) + ├── columns: k:11!null sum:12 + ├── grouping columns: k:11!null + ├── immutable + ├── key: (11) + ├── fd: (11)-->(12) + ├── project + │ ├── columns: k:11!null sal:6 + │ ├── immutable + │ ├── scan emp + │ │ ├── columns: empno:1!null sal:6 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(6) + │ └── projections + │ └── empno:1 + 1 [as=k:11, outer=(1), immutable] + └── aggregations + └── sum [as=sum:12, outer=(6)] + └── sal:6 + +# -------------------------------------------------- +# AggregateExtractProject 
+# -------------------------------------------------- +norm expect=AggregateExtractProject disable=(ProjectMerge,EliminateProject,PruneScanCols,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT sum(sal + 1), avg(sal + 1) +FROM emp +GROUP BY empno +---- +project + ├── columns: sum:12 avg:13 + ├── immutable + └── group-by (hash) + ├── columns: empno:1!null sum:12 avg:13 + ├── grouping columns: empno:1!null + ├── immutable + ├── key: (1) + ├── fd: (1)-->(12,13) + ├── project + │ ├── columns: column14:14 empno:1!null + │ ├── immutable + │ ├── key: (1) + │ ├── fd: (1)-->(14) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8 crdb_internal_mvcc_timestamp:9 tableoid:10 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-10) + │ └── projections + │ └── sal:6 + 1 [as=column14:14, outer=(6), immutable] + └── aggregations + ├── sum [as=sum:12, outer=(14)] + │ └── column14:14 + └── avg [as=avg:13, outer=(14)] + └── column14:14 + +# -------------------------------------------------- +# JoinExtractFilter +# -------------------------------------------------- +norm expect=JoinExtractFilter disable=(MergeSelectInnerJoin,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,FilterIntoJoin,FilterProjectTranspose,JoinAddRedundantSemiJoin,JoinConditionPush,SemiJoinFilterTranspose,SimplifyJoinFilters,EliminateProject,PruneScanCols,PruneJoinLeftCols,PruneJoinRightCols,PruneSemiAntiJoinRightCols,ProjectFilterTranspose) +SELECT * +FROM emp +INNER JOIN dept ON emp.sal > 1000 +---- +project + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8 deptno:11!null dname:12 loc:13 + ├── immutable + ├── key: (1,11) + ├── fd: (1)-->(2-8), (11)-->(12,13) + └── select + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 dept.deptno:11!null dname:12 loc:13 
dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + ├── immutable + ├── key: (1,11) + ├── fd: (1)-->(2-10), (11)-->(12-15) + ├── inner-join (cross) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ ├── key: (1,11) + │ ├── fd: (1)-->(2-10), (11)-->(12-15) + │ ├── semi-join (cross) + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ │ ├── key: (1) + │ │ ├── fd: (1)-->(2-10) + │ │ ├── scan emp + │ │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ │ │ ├── key: (1) + │ │ │ └── fd: (1)-->(2-10) + │ │ ├── scan dept + │ │ │ ├── columns: dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ │ │ ├── key: (11) + │ │ │ └── fd: (11)-->(12-15) + │ │ └── filters (true) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ │ ├── key: (11) + │ │ └── fd: (11)-->(12-15) + │ └── filters (true) + └── filters + └── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + +norm expect=JoinExtractFilter disable=(MergeSelectInnerJoin,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,FilterIntoJoin,FilterProjectTranspose,JoinAddRedundantSemiJoin,JoinConditionPush,SemiJoinFilterTranspose,SimplifyJoinFilters,EliminateProject,PruneScanCols,PruneJoinLeftCols,PruneJoinRightCols,PruneSemiAntiJoinRightCols,ProjectFilterTranspose) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno +---- +project + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── key: (1) + ├── 
fd: (1)-->(2-8,11-13), (11)-->(12,13), (8)==(11), (11)==(8) + └── select + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8!null emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + ├── key: (1) + ├── fd: (1)-->(2-10), (11)-->(12-15), (8)==(11), (11)==(8) + ├── inner-join (cross) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ ├── key: (1,11) + │ ├── fd: (1)-->(2-10), (11)-->(12-15) + │ ├── semi-join (cross) + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ │ ├── key: (1) + │ │ ├── fd: (1)-->(2-10) + │ │ ├── scan emp + │ │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ │ │ ├── key: (1) + │ │ │ └── fd: (1)-->(2-10) + │ │ ├── scan dept + │ │ │ ├── columns: dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ │ │ ├── key: (11) + │ │ │ └── fd: (11)-->(12-15) + │ │ └── filters (true) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ │ ├── key: (11) + │ │ └── fd: (11)-->(12-15) + │ └── filters (true) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +norm expect-not=JoinExtractFilter 
disable=(MergeSelectInnerJoin,PushFilterIntoJoinLeft,PushFilterIntoJoinRight,PushSelectIntoJoinLeft,PushSelectCondLeftIntoJoinLeftAndRight,FilterIntoJoin,FilterProjectTranspose,JoinAddRedundantSemiJoin,JoinConditionPush,SemiJoinFilterTranspose,SimplifyJoinFilters,EliminateProject,PruneScanCols,PruneJoinLeftCols,PruneJoinRightCols,PruneSemiAntiJoinRightCols,ProjectFilterTranspose) +SELECT * FROM emp INNER JOIN dept ON TRUE +---- +project + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8 deptno:11!null dname:12 loc:13 + ├── key: (1,11) + ├── fd: (1)-->(2-8), (11)-->(12,13) + └── inner-join (cross) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + ├── key: (1,11) + ├── fd: (1)-->(2-10), (11)-->(12-15) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ ├── key: (1) + │ └── fd: (1)-->(2-10) + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ ├── key: (11) + │ └── fd: (11)-->(12-15) + └── filters (true) + +# -------------------------------------------------- +# JoinAddRedundantSemiJoin +# -------------------------------------------------- +norm expect=JoinAddRedundantSemiJoin disable=(MergeSelectInnerJoin,ProjectInnerJoinValues,InlineJoinConstantsLeft,InlineJoinConstantsRight,ProjectMerge,EliminateJoinNoColsRight,EliminateJoinNoColsLeft,PruneJoinRightCols,PruneJoinLeftCols,HoistJoinProjectLeft,HoistJoinProjectRight,JoinExtractFilter) +SELECT * +FROM emp +INNER JOIN dept ON emp.deptno = dept.deptno +---- +project + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── key: (1) + ├── fd: (1)-->(2-8,11-13), (11)-->(12,13), 
(8)==(11), (11)==(8) + └── inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8!null emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: (1)-->(2-10), (11)-->(12-15), (8)==(11), (11)==(8) + ├── semi-join (hash) + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ ├── key: (1) + │ ├── fd: (1)-->(2-10) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 emp.crdb_internal_mvcc_timestamp:9 emp.tableoid:10 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-10) + │ ├── scan dept + │ │ ├── columns: dept.deptno:11!null + │ │ └── key: (11) + │ └── filters + │ └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 dept.crdb_internal_mvcc_timestamp:14 dept.tableoid:15 + │ ├── key: (11) + │ └── fd: (11)-->(12-15) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +# -------------------------------------------------- +# MinusMerge +# -------------------------------------------------- +norm expect=MinusMerge disable=(UnionMerge,IntersectMerge,SimplifyExcept,ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +EXCEPT +SELECT empno FROM emp WHERE sal > 1000 +EXCEPT +SELECT empno FROM emp WHERE sal < 500 +---- +except + ├── columns: empno:1!null + ├── left columns: empno:1!null + ├── right columns: empno:21 + ├── immutable + ├── key: (1) + ├── scan emp + │ ├── columns: empno:1!null + │ └── key: (1) + └── project + ├── columns: empno:21!null + ├── 
immutable + ├── key: (21) + ├── union + │ ├── columns: column31:31!null + │ ├── left columns: empno:11 + │ ├── right columns: empno:21 + │ ├── immutable + │ ├── key: (31) + │ ├── project + │ │ ├── columns: empno:11!null + │ │ ├── immutable + │ │ ├── key: (11) + │ │ └── select + │ │ ├── columns: empno:11!null sal:16!null + │ │ ├── immutable + │ │ ├── key: (11) + │ │ ├── fd: (11)-->(16) + │ │ ├── scan emp + │ │ │ ├── columns: empno:11!null sal:16 + │ │ │ ├── key: (11) + │ │ │ └── fd: (11)-->(16) + │ │ └── filters + │ │ └── sal:16 > 1000 [outer=(16), immutable, constraints=(/16: (/1000 - ]; tight)] + │ └── project + │ ├── columns: empno:21!null + │ ├── immutable + │ ├── key: (21) + │ └── select + │ ├── columns: empno:21!null sal:26!null + │ ├── immutable + │ ├── key: (21) + │ ├── fd: (21)-->(26) + │ ├── scan emp + │ │ ├── columns: empno:21!null sal:26 + │ │ ├── key: (21) + │ │ └── fd: (21)-->(26) + │ └── filters + │ └── sal:26 < 500 [outer=(26), immutable, constraints=(/26: (/NULL - /500); tight)] + └── projections + └── column31:31 [as=empno:21, outer=(31)] + +norm expect-not=MinusMerge disable=(UnionMerge,IntersectMerge,SimplifyExcept,ProjectFilterTranspose,JoinConditionPush,FilterProjectTranspose,JoinExtractFilter) +SELECT empno FROM emp +EXCEPT +SELECT empno FROM emp WHERE sal > 1000 +---- +except + ├── columns: empno:1!null + ├── left columns: empno:1!null + ├── right columns: empno:11 + ├── immutable + ├── key: (1) + ├── scan emp + │ ├── columns: empno:1!null + │ └── key: (1) + └── project + ├── columns: empno:11!null + ├── immutable + ├── key: (11) + └── select + ├── columns: empno:11!null sal:16!null + ├── immutable + ├── key: (11) + ├── fd: (11)-->(16) + ├── scan emp + │ ├── columns: empno:11!null sal:16 + │ ├── key: (11) + │ └── fd: (11)-->(16) + └── filters + └── sal:16 > 1000 [outer=(16), immutable, constraints=(/16: (/1000 - ]; tight)] diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt 
b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt new file mode 100644 index 0000000..b6b4dda --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt @@ -0,0 +1,12 @@ +[AggregateExtractProject, Normalize] +(GroupBy + $input_5:* + $aggregations_6:* & (CanExtractProjectFromAggregate $aggregations_6) + $groupingPrivate_7:* +) +=> +(ConstructAggregateExtractProject + $input_5 + $aggregations_6 + $groupingPrivate_7 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateFilterTranspose.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateFilterTranspose.opt new file mode 100644 index 0000000..136d93f --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateFilterTranspose.opt @@ -0,0 +1,18 @@ +[AggregateFilterTranspose, Normalize] +(GroupBy + (Select + $input_0:* + $cond_1:* +) + $aggregations_3:* + $private_5:* & (FiltersBoundBy $cond_1 (GroupingCols $private_5)) +) +=> +(Select + (GroupBy + $input_0 + $aggregations_3 + $private_5 +) + $cond_1 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt new file mode 100644 index 0000000..f955163 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt @@ -0,0 +1,38 @@ +[AggregateJoinJoinRemove, Normalize] +(DistinctOn + (LeftJoin + (LeftJoin + $left_0:* + $middle_1:* + * + ) + $right_2:* + $rightFilters_3:* + ) + $aggregations_4:[] + $groupingPrivate_5:(GroupingPrivate $groupingCols_6:* $ordering_7:*) & + (ColsAreEmpty + (IntersectionCols + (OutputCols $middle_1) + (UnionCols + (FilterOuterCols $rightFilters_3) + $groupingCols_6 + ) + ) + ) & + (OrderingCanProjectCols + $ordering_7 + (UnionCols (OutputCols $left_0) (OutputCols $right_2)) + ) +) +=> +(DistinctOn + (LeftJoin + $left_0 + $right_2 + $rightFilters_3 + (EmptyJoinPrivate) + ) 
+ $aggregations_4 + $groupingPrivate_5 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt new file mode 100644 index 0000000..6293544 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt @@ -0,0 +1,24 @@ +[AggregateJoinRemove, Normalize] +(DistinctOn + (LeftJoin + $left_0:* + * + * + ) + $aggregations_1:[] + $groupingPrivate_2:(GroupingPrivate $groupingCols_3:* $ordering_4:*) & + (ColsAreSubset + $groupingCols_3 + $leftCols_5:(OutputCols $left_0) + ) & + (OrderingCanProjectCols + $ordering_4 + $leftCols_5 + ) +) +=> +(DistinctOn + $left_0 + $aggregations_1 + $groupingPrivate_2 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectConstantToDummyJoin.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectConstantToDummyJoin.opt new file mode 100644 index 0000000..07cc0da --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectConstantToDummyJoin.opt @@ -0,0 +1,12 @@ +[AggregateProjectConstantToDummyJoin, Normalize] +(GroupBy + $input_0:(Project * * *) + $aggregations_1:* + $groupingPrivate_2:* & (HasConstantGroupingCols $input_0 $groupingPrivate_2) +) +=> +(ConstructAggregateProjectConstantToDummyJoin + $input_0 + $aggregations_1 + $groupingPrivate_2 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt new file mode 100644 index 0000000..4985f1a --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt @@ -0,0 +1,16 @@ +[AggregateProjectMerge, Normalize] +(GroupBy + $input_0:(Project + $input_1:* + * + * + ) + $aggregations_2:* + $groupingPrivate_3:* & (CanMergeProjectIntoAggregate $input_0 $groupingPrivate_3) +) +=> +(GroupBy + $input_1 + (MergeProjectIntoAggregate $input_0 
$aggregations_2) + $groupingPrivate_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterAggregateTranspose.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterAggregateTranspose.opt new file mode 100644 index 0000000..b8a873f --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterAggregateTranspose.opt @@ -0,0 +1,18 @@ +[FilterAggregateTranspose, Normalize] +(Select + (GroupBy + $input_0:* + $aggregations_2:* + $private_4:* +) + $cond_5:* & (FiltersBoundBy $cond_5 (GroupingCols $private_4)) +) +=> +(GroupBy + (Select + $input_0 + $cond_5 +) + $aggregations_2 + $private_4 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterIntoJoin.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterIntoJoin.opt new file mode 100644 index 0000000..2b0cf79 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterIntoJoin.opt @@ -0,0 +1,17 @@ +[FilterIntoJoin, Normalize] +(Select + (InnerJoin + $input_0:* + $input_1:* + $cond_2:* + $private_3:* +) + $cond_7:* +) +=> +(InnerJoin + $input_0 + $input_1 + (ConcatFilters $cond_2 $cond_7) + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterMerge.opt new file mode 100644 index 0000000..fd1e5e7 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterMerge.opt @@ -0,0 +1,13 @@ +[FilterMerge, Normalize] +(Select + (Select + $input_0:* + $cond_1:* +) + $cond_2:* +) +=> +(Select + $input_0 + (ConcatFilters $cond_1 $cond_2) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterProjectTranspose.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterProjectTranspose.opt new file mode 100644 index 0000000..a5b4122 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterProjectTranspose.opt @@ -0,0 +1,19 @@ +[FilterProjectTranspose, Normalize] +(Select + (Project + $input_0:* 
+ $proj_1:* + $passthrough_2:* +) + $cond_3:* & + (FiltersBoundBy $cond_3 $inputCols_4:(OutputCols $input_0)) +) +=> +(Project + (Select + $input_0 + $cond_3 +) + $proj_1 + $passthrough_2 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt new file mode 100644 index 0000000..4421882 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt @@ -0,0 +1,11 @@ +[FilterReduceFalse, Normalize] +(Select + $input_0:* + $on_1:[ + ... + $item_2:(FiltersItem (False)) + ... + ] +) +=> +(ConstructEmptyValues (OutputCols $input_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceTrue.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceTrue.opt new file mode 100644 index 0000000..fe74434 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceTrue.opt @@ -0,0 +1,7 @@ +[FilterReduceTrue, Normalize] +(Select + $input_0:* + [] +) +=> +$input_0 diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterSetOpTranspose.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterSetOpTranspose.opt new file mode 100644 index 0000000..729afe5 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterSetOpTranspose.opt @@ -0,0 +1,26 @@ +[FilterSetOpTranspose, Normalize] +(Select + $input_0:(Union $left_1:* $right_2:* $colmap_3:*) + $filter_4:[ + ... + $item_5:* & + (CanMapOnSetOp $item_5) & + (IsBoundBy $item_5 $inputCols_6:(OutputCols $input_0)) + ... 
+ ] +) +=> +(Select + (Union + (Select + $left_1 + [ (FiltersItem (MapSetOpFilterLeft $item_5 $colmap_3)) ] + ) + (Select + $right_2 + [ (FiltersItem (MapSetOpFilterRight $item_5 $colmap_3)) ] + ) + $colmap_3 + ) + (RemoveFiltersItem $filter_4 $item_5) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/IntersectMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/IntersectMerge.opt new file mode 100644 index 0000000..5c6bb7c --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/IntersectMerge.opt @@ -0,0 +1,20 @@ +[IntersectMerge, Normalize] +(Intersect + $left_5:(Intersect + $leftLeft_0:* + $leftRight_1:* + $innerPrivate_2:(SetPrivate $innerLeftCols_3:* $innerRightCols_4:* *) + ) + $right_6:* + $outerPrivate_7:(SetPrivate * $outerRightCols_8:* $outerOutCols_9:*) +) +=> +(Intersect + $leftLeft_0 + (Intersect + $leftRight_1 + $right_6 + (MakeSetPrivate $innerRightCols_4 $outerRightCols_8 $innerRightCols_4) + ) + (MakeSetPrivate $innerLeftCols_3 $innerRightCols_4 $outerOutCols_9) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinAddRedundantSemiJoin.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinAddRedundantSemiJoin.opt new file mode 100644 index 0000000..a2b87de --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinAddRedundantSemiJoin.opt @@ -0,0 +1,19 @@ +[JoinAddRedundantSemiJoin, Normalize] +(InnerJoin + $left_4:^(Values) + $right_5:* + $filters_7:* + $private_3:* & ^(IsRedundantSemiJoin $left_4 $right_5 $filters_7) +) +=> +(InnerJoin + (SemiJoin + $left_4 + $right_5 + $filters_7 + (EmptyJoinPrivate) + ) + $right_5 + $filters_7 + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinCommute.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinCommute.opt new file mode 100644 index 0000000..8e31c46 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinCommute.opt @@ -0,0 +1,22 @@ +[JoinCommute, Normalize] 
+(InnerJoin + $left_4:* + $right_5:* + $on_6:* + $private_3:* & + (CanCommuteJoin $left_4 $right_5) +) +=> +(Project + (InnerJoin + $right_5 + $left_4 + $on_6 + (CommuteJoinFlags $private_3) + ) + (SwapJoinOutputColumns + (OutputCols $left_4) + (OutputCols $right_5) + ) + (MakeEmptyColSet) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt new file mode 100644 index 0000000..e73d9d4 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt @@ -0,0 +1,22 @@ +[JoinConditionPush, Normalize] +(InnerJoin + $left_0:* & ^(HasOuterCols $left_0) + $right_1:* & ^(HasOuterCols $right_1) + $on_2:* & + (HasBoundConditions + $on_2 + (OutputCols $left_0) + (OutputCols $right_1) + ) + $private_3:* +) +=> +(InnerJoin + (Select $left_0 (ExtractBoundConditions $on_2 (OutputCols $left_0))) + (Select $right_1 (ExtractBoundConditions $on_2 (OutputCols $right_1))) + (ExtractUnboundConditions + (ExtractUnboundConditions $on_2 (OutputCols $left_0)) + (OutputCols $right_1) + ) + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt new file mode 100644 index 0000000..f5b3c43 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt @@ -0,0 +1,15 @@ +[JoinExtractFilter, Normalize] +(InnerJoin + $left_7:* + $right_8:* + $on_9:* & + (CanExtractJoinFilter $left_7 $right_8 $on_9) + $private_3:* +) +=> +(ConstructJoinExtractFilterResult + $left_7 + $right_8 + $on_9 + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinPushTransitivePredicates.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinPushTransitivePredicates.opt new file mode 100644 index 0000000..719c56a --- /dev/null +++ 
b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinPushTransitivePredicates.opt @@ -0,0 +1,17 @@ +[JoinPushTransitivePredicates, Normalize] +(Select + (InnerJoin + $input_0:* + $input_1:* + $cond_2:* & ^(IsFilterEmpty $cond_2) + $private_3:* +) + $cond_7:* +) +=> +(InnerJoin + $input_0 + $input_1 + (ConcatFilters $cond_2 $cond_7) + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt new file mode 100644 index 0000000..12ce5eb --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt @@ -0,0 +1,21 @@ +[JoinReduceFalse, Normalize] +(InnerJoin + $input_0:* + $input_1:* + $on_2:[ + ... + $item_3:(FiltersItem + (And * (False)) + ) + ... + ] & + ^(IsFilterFalse $on_2) + $private_4:* +) +=> +(InnerJoin + $input_0 + $input_1 + [ (FiltersItem (False)) ] + $private_4 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt new file mode 100644 index 0000000..c894d23 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt @@ -0,0 +1,18 @@ +[JoinReduceTrue, Normalize] +(InnerJoin + $input_0:* + $input_1:* + $on_2:[ + ... + $item_3:(FiltersItem (True)) + ... 
+ ] + $private_4:* +) +=> +(InnerJoin + $input_0 + $input_1 + (RemoveFiltersItem $on_2 $item_3) + $private_4 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt new file mode 100644 index 0000000..8e01f74 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt @@ -0,0 +1,18 @@ +[MinusMerge, Normalize] +(Except + $left_0:(Except + $leftLeft_1:* + $leftRight_2:* + $innerPrivate_3:* + ) + $right_4:* + $outerPrivate_5:* +) +=> +(ConstructMinusMergeResult + $leftLeft_1 + $leftRight_2 + $right_4 + $innerPrivate_3 + $outerPrivate_5 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectAggregateMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectAggregateMerge.opt new file mode 100644 index 0000000..657e539 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectAggregateMerge.opt @@ -0,0 +1,27 @@ +[ProjectAggregateMerge, Normalize] +(Project + $input_3:(GroupBy + $input_0:* + $aggregations_1:* + $groupingPrivate_2:* + ) + $projections_4:* + $passthrough_5:* & + (CanPruneAggCols + $aggregations_1 + $needed_6:(UnionCols + (ProjectionOuterCols $projections_4) + $passthrough_5 + ) + ) +) +=> +(Project + (GroupBy + $input_0 + (PruneAggCols $aggregations_1 $needed_6) + $groupingPrivate_2 + ) + $projections_4 + $passthrough_5 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectFilterTranspose.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectFilterTranspose.opt new file mode 100644 index 0000000..9921d3b --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectFilterTranspose.opt @@ -0,0 +1,18 @@ +[ProjectFilterTranspose, Normalize] +(Project + (Select + $input_0:* + $cond_1:* +) + $proj_2:* + $passthrough_3:* & (FiltersBoundBy $cond_1 $passthrough_3) +) +=> +(Select + (Project + $input_0 + $proj_2 + $passthrough_3 +) + $cond_1 +) diff 
--git a/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectMerge.opt new file mode 100644 index 0000000..cc0543b --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/ProjectMerge.opt @@ -0,0 +1,23 @@ +[ProjectMerge, Normalize] +(Project + $input:(Project + $input_1:* + $proj_2:* + $innerPassthrough_3:*) + $proj_0:* & + (CanMergeProjections $proj_0 $proj_2) + $passthrough_4:* +) +=> +(Project + $input_1 + (MergeProjections + $proj_0 + $proj_2 + $passthrough_4 + ) + (DifferenceCols + $innerPassthrough_3 + (ProjectionCols $proj_2) + ) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt new file mode 100644 index 0000000..c7a31cb --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt @@ -0,0 +1,7 @@ +[PruneEmptyFilter, Normalize] +(Select + $input_2:* & (HasZeroRows $input_2) + $filters:* +) +=> +$input_2 diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt new file mode 100644 index 0000000..43bed9c --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt @@ -0,0 +1,7 @@ +[PruneEmptyIntersect, Normalize] +(Intersect + $left_0:* + $right_1:* & (HasZeroRows $right_1) +) +=> +(ConstructEmptyValues (OutputCols $left_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt new file mode 100644 index 0000000..b1b40fe --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt @@ -0,0 +1,7 @@ +[PruneEmptyMinus, Normalize] +(Except + $left_0:* & (HasZeroRows $left_0) + $right:* +) +=> +(ConstructEmptyValues (OutputCols $left_0)) diff --git 
a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt new file mode 100644 index 0000000..57a6da8 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt @@ -0,0 +1,8 @@ +[PruneEmptyProject, Normalize] +(Project + $input_0:* & (HasZeroRows $input_0) + $projections:* + $passthrough:* +) +=> +(ConstructEmptyValues (UnionCols (ProjectionCols $projections) $passthrough)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt new file mode 100644 index 0000000..95416c0 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt @@ -0,0 +1,8 @@ +[PruneEmptyUnion, Normalize] +(Union + $left_0:* & (HasZeroRows $left_0) + $right_1:* & (HasZeroRows $right_1) + $private_2:(SetPrivate * * $outCols_3:*) +) +=> +(ConstructEmptyValues (ColListToSet $outCols_3)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/SemiJoinFilterTranspose.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/SemiJoinFilterTranspose.opt new file mode 100644 index 0000000..9b8dec3 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/SemiJoinFilterTranspose.opt @@ -0,0 +1,20 @@ +[SemiJoinFilterTranspose, Normalize] +(Select + (SemiJoin + $input_0:* + $input_1:* + $cond_2:* + $private_3:* +) + $cond_4:* +) +=> +(SemiJoin + (Select + $input_0 + $cond_4 +) + $input_1 + $cond_2 + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/UnionMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/UnionMerge.opt new file mode 100644 index 0000000..e50cf28 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/UnionMerge.opt @@ -0,0 +1,20 @@ +[UnionMerge, Normalize] +(Union + $left_6:(Union + $leftLeft_0:* + $leftRight_1:* + $innerPrivate_2:(SetPrivate $innerLeftCols_3:* 
$innerRightCols_4:* $innerOutCols_5:*) + ) + $right_7:* + $outerPrivate_8:(SetPrivate * $outerRightCols_9:* $outerOutCols_10:*) +) +=> +(Union + $leftLeft_0 + (Union + $leftRight_1 + $right_7 + (MakeSetPrivate $innerRightCols_4 $outerRightCols_9 $innerOutCols_5) + ) + (MakeSetPrivate $innerLeftCols_3 $innerOutCols_5 $outerOutCols_10) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/UnionPullUpConstants.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/UnionPullUpConstants.opt new file mode 100644 index 0000000..26ec24b --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/UnionPullUpConstants.opt @@ -0,0 +1,33 @@ +[UnionPullUpConstants, Normalize] +(UnionAll + $left_4:(Project + $leftInput_1:* + $leftProjections_2:* + $leftPassthrough_3:* + ) + $right_9:(Project + $rightInput_6:* + $rightProjections_7:* + $rightPassthrough_8:* + ) + $private_10:(SetPrivate $leftCols_11:* $rightCols_12:* $outCols_13:*) & + (HasMatchingConstantsFromUnion + $leftProjections_2 + $rightProjections_7 + $leftCols_11 + $rightCols_12 + $outCols_13 + ) +) +=> +(UnionPullUpConstantsReplace + $left_4 + $right_9 + $leftProjections_2 + $rightProjections_7 + $private_10 + $leftInput_1 + $rightInput_6 + $leftPassthrough_3 + $rightPassthrough_8 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/UnionToDistinct.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/UnionToDistinct.opt new file mode 100644 index 0000000..23a6b81 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/UnionToDistinct.opt @@ -0,0 +1,29 @@ +[UnionToDistinct, Normalize] +(Union + $left_1:* + $right_3:* + $private_4:(SetPrivate $leftCols_5:* $rightCols_6:* $outCols_7:*) & + (Let + ($keyCols_8 $ok_9):(CanConvertUnionToDistinctUnionAll + $leftCols_5 + $rightCols_6 + ) + $ok_9 + ) +) +=> +(DistinctOn + (UnionAll $left_1 $right_3 $private_4) + (MakeAggCols + ConstAgg + (TranslateColSet + (DifferenceCols (OutputCols $left_1) $keyCols_8) + $leftCols_5 + 
$outCols_7 + ) + ) + (MakeGrouping + (TranslateColSet $keyCols_8 $leftCols_5 $outCols_7) + (EmptyOrdering) + ) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/HelperFunctions.go b/src/main/java/org/qed/Backends/Cockroach/HelperFunctions.go new file mode 100644 index 0000000..946b34d --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/HelperFunctions.go @@ -0,0 +1,916 @@ +package norm + +import ( + "github.com/cockroachdb/cockroach/pkg/sql/opt" + "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/intsets" + "github.com/cockroachdb/errors" +) + +func (c *CustomFuncs) HasConstantGroupingCols( + input memo.RelExpr, groupingPrivate *memo.GroupingPrivate, +) bool { + _, _, _, ok := c.extractConstantGroupingColsAndBuildDummy(input, groupingPrivate) + return ok +} + +func (c *CustomFuncs) extractConstantGroupingColsAndBuildDummy( + input memo.RelExpr, groupingPrivate *memo.GroupingPrivate, +) (constantCols opt.ColSet, constantValues memo.ScalarListExpr, dummyCols opt.ColList, ok bool) { + project, ok := input.(*memo.ProjectExpr) + if !ok { + return opt.ColSet{}, nil, nil, false + } + + groupingCols := groupingPrivate.GroupingCols + constantCols = opt.ColSet{} + constantToValue := make(map[opt.ColumnID]opt.ScalarExpr) + + for i := range project.Projections { + item := &project.Projections[i] + if groupingCols.Contains(item.Col) && opt.IsConstValueOp(item.Element) { + constantCols.Add(item.Col) + constantToValue[item.Col] = item.Element + } + } + + if constantCols.Empty() { + return opt.ColSet{}, nil, nil, false + } + + md := c.mem.Metadata() + constantColList := constantCols.ToList() + constantValues = make(memo.ScalarListExpr, len(constantColList)) + dummyCols = make(opt.ColList, len(constantColList)) + for i, col := range constantColList { + constantValues[i] = constantToValue[col] + dummyCols[i] = md.AddColumn("", constantToValue[col].DataType()) + } + + 
return constantCols, constantValues, dummyCols, true +} + +func (c *CustomFuncs) GetConstantGroupingCols( + input memo.RelExpr, groupingPrivate *memo.GroupingPrivate, +) opt.ColSet { + constantCols, _, _, _ := c.extractConstantGroupingColsAndBuildDummy(input, groupingPrivate) + return constantCols +} + +func (c *CustomFuncs) GetConstantValues( + input memo.RelExpr, groupingPrivate *memo.GroupingPrivate, +) memo.ScalarListExpr { + _, constantValues, _, _ := c.extractConstantGroupingColsAndBuildDummy(input, groupingPrivate) + return constantValues +} + +func (c *CustomFuncs) GetDummyCols( + input memo.RelExpr, groupingPrivate *memo.GroupingPrivate, +) opt.ColList { + _, _, dummyCols, _ := c.extractConstantGroupingColsAndBuildDummy(input, groupingPrivate) + return dummyCols +} + +func (c *CustomFuncs) ConstructDummyValuesTable( + constantValues memo.ScalarListExpr, dummyCols opt.ColList, +) memo.RelExpr { + if len(constantValues) == 0 || len(dummyCols) == 0 { + panic(errors.AssertionFailedf("ConstructDummyValuesTable called with empty constantValues or dummyCols")) + } + tupleTypes := make([]*types.T, len(constantValues)) + for i := range constantValues { + tupleTypes[i] = constantValues[i].DataType() + } + tupleTyp := types.MakeTuple(tupleTypes) + tuple := c.f.ConstructTuple(constantValues, tupleTyp) + rows := memo.ScalarListExpr{tuple} + return c.f.ConstructValues(rows, &memo.ValuesPrivate{ + Cols: dummyCols, + ID: c.mem.Metadata().NextUniqueID(), + }) +} + +func (c *CustomFuncs) RemapProjectionsForDummyJoin( + projections memo.ProjectionsExpr, + constantCols opt.ColSet, + dummyCols opt.ColList, +) memo.ProjectionsExpr { + constantColList := constantCols.ToList() + constantToDummy := make(map[opt.ColumnID]opt.ColumnID) + for i := range constantColList { + if i < len(dummyCols) { + constantToDummy[constantColList[i]] = dummyCols[i] + } + } + + newProjections := make(memo.ProjectionsExpr, 0, len(projections)) + for i := range projections { + item := &projections[i] + 
if dummyCol, ok := constantToDummy[item.Col]; ok { + newProjections = append(newProjections, c.f.ConstructProjectionsItem( + c.f.ConstructVariable(dummyCol), + item.Col, + )) + } else { + newProjections = append(newProjections, *item) + } + } + return newProjections +} + +func (c *CustomFuncs) RemapGroupingColsForDummyJoin( + groupingPrivate *memo.GroupingPrivate, + constantCols opt.ColSet, + dummyCols opt.ColList, +) *memo.GroupingPrivate { + constantColList := constantCols.ToList() + constantToDummy := make(map[opt.ColumnID]opt.ColumnID) + for i := range constantColList { + if i < len(dummyCols) { + constantToDummy[constantColList[i]] = dummyCols[i] + } + } + + newGroupingCols := opt.ColSet{} + for col, ok := groupingPrivate.GroupingCols.Next(0); ok; col, ok = groupingPrivate.GroupingCols.Next(col + 1) { + if dummyCol, isConst := constantToDummy[col]; isConst { + newGroupingCols.Add(dummyCol) + } else { + newGroupingCols.Add(col) + } + } + + newPrivate := *groupingPrivate + newPrivate.GroupingCols = newGroupingCols + return &newPrivate +} + +func (c *CustomFuncs) ConstructAggregateProjectConstantToDummyJoin( + input memo.RelExpr, + aggregations memo.AggregationsExpr, + groupingPrivate *memo.GroupingPrivate, +) memo.RelExpr { + project := input.(*memo.ProjectExpr) + + constantCols, constantValues, dummyCols, ok := c.extractConstantGroupingColsAndBuildDummy(project, groupingPrivate) + if !ok { + panic(errors.AssertionFailedf("should have matched")) + } + + values := c.ConstructDummyValuesTable(constantValues, dummyCols) + + joinPrivate := &memo.JoinPrivate{} + + filters := memo.FiltersExpr{{Condition: c.f.ConstructTrue()}} + join := c.f.ConstructInnerJoin(project.Input, values, filters, joinPrivate) + + newProjections := c.RemapProjectionsForDummyJoin(project.Projections, constantCols, dummyCols) + newProject := c.f.ConstructProject(join, newProjections, project.Passthrough) + + return c.f.ConstructGroupBy(newProject, aggregations, groupingPrivate) +} + +func (c 
*CustomFuncs) CanMergeProjectIntoAggregate( + input memo.RelExpr, + groupingPrivate *memo.GroupingPrivate, +) bool { + project, ok := input.(*memo.ProjectExpr) + if !ok { + return false + } + + if !groupingPrivate.GroupingCols.SubsetOf(project.Passthrough) { + return false + } + + if !groupingPrivate.Ordering.ColSet().SubsetOf(project.Passthrough) { + return false + } + + if len(project.Projections) == 0 { + return false + } + + return true +} + +func (c *CustomFuncs) MergeProjectIntoAggregate( + input memo.RelExpr, + aggregations memo.AggregationsExpr, +) memo.AggregationsExpr { + project := input.(*memo.ProjectExpr) + + colToExpr := make(map[opt.ColumnID]opt.ScalarExpr) + for i := range project.Projections { + item := &project.Projections[i] + colToExpr[item.Col] = item.Element + } + + newAggs := make(memo.AggregationsExpr, len(aggregations)) + for i := range aggregations { + agg := aggregations[i].Agg + + var replace ReplaceFunc + replace = func(e opt.Expr) opt.Expr { + if v, ok := e.(*memo.VariableExpr); ok { + if expr, found := colToExpr[v.Col]; found { + return expr + } + } + return c.f.Replace(e, replace) + } + + newAgg := replace(agg).(opt.ScalarExpr) + newAggs[i] = c.f.ConstructAggregationsItem(newAgg, aggregations[i].Col) + } + + return newAggs +} + +func (c *CustomFuncs) ExtractMatchingConstantsFromUnion( + leftProjections memo.ProjectionsExpr, + rightProjections memo.ProjectionsExpr, + leftCols opt.ColList, + rightCols opt.ColList, + outCols opt.ColList, +) (constantPositions []int, constantValues memo.ScalarListExpr, ok bool) { + if len(leftCols) != len(rightCols) || len(leftCols) != len(outCols) { + return nil, nil, false + } + + leftColToProj := make(map[opt.ColumnID]int) + for i := range leftProjections { + leftColToProj[leftProjections[i].Col] = i + } + rightColToProj := make(map[opt.ColumnID]int) + for i := range rightProjections { + rightColToProj[rightProjections[i].Col] = i + } + + constantPositions = make([]int, 0) + constantValues = 
make(memo.ScalarListExpr, 0) + + for outIdx := range outCols { + leftCol := leftCols[outIdx] + rightCol := rightCols[outIdx] + + leftProjIdx, leftHasProj := leftColToProj[leftCol] + rightProjIdx, rightHasProj := rightColToProj[rightCol] + + if !leftHasProj || !rightHasProj { + continue + } + + leftItem := &leftProjections[leftProjIdx] + rightItem := &rightProjections[rightProjIdx] + + if opt.IsConstValueOp(leftItem.Element) && opt.IsConstValueOp(rightItem.Element) { + if c.IsConstValueEqual(leftItem.Element, rightItem.Element) { + constantPositions = append(constantPositions, outIdx) + constantValues = append(constantValues, leftItem.Element) + } + } + } + + if len(constantPositions) == 0 { + return nil, nil, false + } + + return constantPositions, constantValues, true +} + +func (c *CustomFuncs) MakeColSetFromPositions( + positions []int, + colList opt.ColList, +) opt.ColSet { + result := opt.ColSet{} + for _, pos := range positions { + if pos < len(colList) { + result.Add(colList[pos]) + } + } + return result +} + +func (c *CustomFuncs) ComputeNeededColsForUnionPullUp( + constantPositions []int, + outCols opt.ColList, +) opt.ColSet { + constantCols := c.MakeColSetFromPositions(constantPositions, outCols) + return outCols.ToSet().Difference(constantCols) +} + +func (c *CustomFuncs) AddConstantsToProjections( + constantPositions []int, + constantValues memo.ScalarListExpr, + outCols opt.ColList, +) memo.ProjectionsExpr { + if len(constantPositions) != len(constantValues) { + panic(errors.AssertionFailedf("constantPositions and constantValues must have same length")) + } + + projections := make(memo.ProjectionsExpr, 0, len(constantPositions)) + for i, pos := range constantPositions { + if pos >= len(outCols) { + panic(errors.AssertionFailedf("position %d out of range for outCols", pos)) + } + outCol := outCols[pos] + projections = append(projections, c.f.ConstructProjectionsItem(constantValues[i], outCol)) + } + + return projections +} + +func (c *CustomFuncs) 
HasMatchingConstantsFromUnion( + leftProjections memo.ProjectionsExpr, + rightProjections memo.ProjectionsExpr, + leftCols opt.ColList, + rightCols opt.ColList, + outCols opt.ColList, +) bool { + _, _, ok := c.ExtractMatchingConstantsFromUnion(leftProjections, rightProjections, leftCols, rightCols, outCols) + return ok +} + +func (c *CustomFuncs) UnionPullUpConstantsReplace( + left memo.RelExpr, + right memo.RelExpr, + leftProjections memo.ProjectionsExpr, + rightProjections memo.ProjectionsExpr, + private *memo.SetPrivate, + leftInput memo.RelExpr, + rightInput memo.RelExpr, + leftPassthrough opt.ColSet, + rightPassthrough opt.ColSet, +) memo.RelExpr { + constantPositions, constantValues, ok := c.ExtractMatchingConstantsFromUnion( + leftProjections, rightProjections, + private.LeftCols, private.RightCols, private.OutCols, + ) + if !ok { + panic(errors.AssertionFailedf("HasMatchingConstantsFromUnion should have returned true")) + } + + neededCols := c.ComputeNeededColsForUnionPullUp(constantPositions, private.OutCols) + + leftProject := left.(*memo.ProjectExpr) + rightProject := right.(*memo.ProjectExpr) + + neededLeftCols := c.NeededColMapLeft(neededCols, private) + neededRightCols := c.NeededColMapRight(neededCols, private) + prunedLeft := c.PruneCols(leftProject.Input, neededLeftCols) + prunedRight := c.PruneCols(rightProject.Input, neededRightCols) + + adjustedPrivate := c.PruneSetPrivate(neededCols, private) + + union := c.f.ConstructUnionAll(prunedLeft, prunedRight, adjustedPrivate) + + unionOutputCols := adjustedPrivate.OutCols + mergedProjections := make(memo.ProjectionsExpr, 0, len(private.OutCols)) + passthrough := opt.ColSet{} + + constantProjMap := make(map[opt.ColumnID]memo.ProjectionsItem) + for i, pos := range constantPositions { + if pos < len(private.OutCols) { + outCol := private.OutCols[pos] + constantProjMap[outCol] = c.f.ConstructProjectionsItem(constantValues[i], outCol) + } + } + + unionColIdx := 0 + for _, outCol := range private.OutCols { + 
if constantProj, isConst := constantProjMap[outCol]; isConst { + mergedProjections = append(mergedProjections, constantProj) + } else if unionColIdx < len(unionOutputCols) { + unionCol := unionOutputCols[unionColIdx] + if unionCol == outCol { + passthrough.Add(outCol) + } else { + mergedProjections = append(mergedProjections, c.f.ConstructProjectionsItem( + c.f.ConstructVariable(unionCol), + outCol, + )) + } + unionColIdx++ + } + } + + return c.f.ConstructProject(union, mergedProjections, passthrough) +} + +func (c *CustomFuncs) ColListToSet(colList opt.ColList) opt.ColSet { + return colList.ToSet() +} + +func (c *CustomFuncs) CanCommuteJoin(left, right memo.RelExpr) bool { + return c.OutputCols(left).Len() <= c.OutputCols(right).Len() +} + +func (c *CustomFuncs) SwapJoinOutputColumns( + leftCols opt.ColSet, + rightCols opt.ColSet, +) memo.ProjectionsExpr { + projections := make(memo.ProjectionsExpr, 0, leftCols.Len()+rightCols.Len()) + md := c.mem.Metadata() + + for col, ok := rightCols.Next(0); ok; col, ok = rightCols.Next(col + 1) { + colMeta := md.ColumnMeta(col) + newCol := md.AddColumn(colMeta.Alias, colMeta.Type) + projections = append(projections, c.f.ConstructProjectionsItem( + c.f.ConstructVariable(col), + newCol, + )) + } + + for col, ok := leftCols.Next(0); ok; col, ok = leftCols.Next(col + 1) { + colMeta := md.ColumnMeta(col) + newCol := md.AddColumn(colMeta.Alias, colMeta.Type) + projections = append(projections, c.f.ConstructProjectionsItem( + c.f.ConstructVariable(col), + newCol, + )) + } + + return projections +} + +func (c *CustomFuncs) HasBoundConditions( + filters memo.FiltersExpr, + leftCols opt.ColSet, + rightCols opt.ColSet, +) bool { + for i := range filters { + if c.IsBoundBy(&filters[i], leftCols) || c.IsBoundBy(&filters[i], rightCols) { + return true + } + } + return false +} + +func (c *CustomFuncs) IsFilterTrue(filters memo.FiltersExpr) bool { + if len(filters) == 0 { + return true + } + for i := range filters { + condition := 
filters[i].Condition + if condition.Op() != opt.TrueOp { + return false + } + } + return true +} + +func (c *CustomFuncs) CanExtractJoinFilter( + left memo.RelExpr, + right memo.RelExpr, + on memo.FiltersExpr, +) bool { + if c.IsFilterTrue(on) || c.IsFilterEmpty(on) { + return false + } + if c.HasOuterCols(left) || c.HasOuterCols(right) { + return false + } + + allCols := left.Relational().OutputCols.Union(right.Relational().OutputCols) + hasNonTrueCondition := false + for i := range on { + if on[i].Condition.Op() != opt.TrueOp { + hasNonTrueCondition = true + if !c.IsBoundBy(&on[i], allCols) { + return false + } + } + } + return hasNonTrueCondition +} + +func (c *CustomFuncs) ConstructJoinExtractFilterResult( + left memo.RelExpr, + right memo.RelExpr, + on memo.FiltersExpr, + private *memo.JoinPrivate, +) memo.RelExpr { + var disabledRules intsets.Fast + disabledRules.Add(int(opt.FilterIntoJoin)) + disabledRules.Add(int(opt.MergeSelectInnerJoin)) + disabledRules.Add(int(opt.JoinPushTransitivePredicates)) + + var result memo.RelExpr + c.f.DisableOptimizationRulesTemporarily(disabledRules, func() { + result = c.f.ConstructSelect( + c.f.ConstructInnerJoin(left, right, memo.EmptyFiltersExpr, private), + on, + ) + }) + return result +} + +func (c *CustomFuncs) CanExtractProjectFromAggregate( + aggregations memo.AggregationsExpr, +) bool { + for i := range aggregations { + agg := aggregations[i].Agg + if agg.ChildCount() > 0 { + arg := agg.Child(0) + if scalarArg, ok := arg.(opt.ScalarExpr); ok { + if _, ok := scalarArg.(*memo.VariableExpr); !ok { + if !opt.IsConstValueOp(scalarArg) { + return true + } + } + } + } + } + return false +} + +func (c *CustomFuncs) ExtractProjectFromAggregate( + input memo.RelExpr, + aggregations memo.AggregationsExpr, +) memo.RelExpr { + inputCols := c.OutputCols(input) + + var pb projectBuilder + pb.init(c, inputCols) + + for i := range aggregations { + agg := aggregations[i].Agg + if agg.ChildCount() > 0 { + arg := agg.Child(0) + if 
scalarArg, ok := arg.(opt.ScalarExpr); ok { + pb.add(scalarArg) + } + } + } + + return pb.buildProject(input) +} + +func (c *CustomFuncs) RemapAggregationsAfterExtractProject( + aggregations memo.AggregationsExpr, + input memo.RelExpr, +) memo.AggregationsExpr { + inputCols := c.OutputCols(input) + var pb projectBuilder + pb.init(c, inputCols) + + exprToVar := make(map[opt.ScalarExpr]opt.ScalarExpr) + for i := range aggregations { + agg := aggregations[i].Agg + if agg.ChildCount() > 0 { + arg := agg.Child(0) + if scalarArg, ok := arg.(opt.ScalarExpr); ok { + if _, exists := exprToVar[scalarArg]; !exists { + varExpr := pb.add(scalarArg) + exprToVar[scalarArg] = varExpr + } + } + } + } + + newAggs := make(memo.AggregationsExpr, len(aggregations)) + for i := range aggregations { + agg := aggregations[i].Agg + var newArg opt.ScalarExpr + + if agg.ChildCount() > 0 { + arg := agg.Child(0) + if scalarArg, ok := arg.(opt.ScalarExpr); ok { + if varExpr, exists := exprToVar[scalarArg]; exists { + newArg = varExpr + } else { + newArg = scalarArg + } + } else { + newArg = nil + } + } + + var newAgg opt.ScalarExpr + if newArg != nil && agg.ChildCount() > 0 { + var replace ReplaceFunc + replace = func(e opt.Expr) opt.Expr { + if e == agg.Child(0) { + return newArg + } + return c.f.Replace(e, replace) + } + newAgg = replace(agg).(opt.ScalarExpr) + } else { + newAgg = agg + } + + newAggs[i] = c.f.ConstructAggregationsItem(newAgg, aggregations[i].Col) + } + + return newAggs +} + +func (c *CustomFuncs) ConstructAggregateExtractProject( + input memo.RelExpr, + aggregations memo.AggregationsExpr, + groupingPrivate *memo.GroupingPrivate, +) memo.RelExpr { + inputCols := c.OutputCols(input) + + var pb projectBuilder + pb.init(c, inputCols) + + // exprToVar deduplicates: multiple aggregates using the same expression share + // one projected column rather than each getting a separate copy. 
+ exprToVar := make(map[opt.ScalarExpr]opt.ScalarExpr) + + newAggs := make(memo.AggregationsExpr, len(aggregations)) + for i := range aggregations { + agg := aggregations[i].Agg + var newArg opt.ScalarExpr + + if agg.ChildCount() > 0 { + arg := agg.Child(0) + if scalarArg, ok := arg.(opt.ScalarExpr); ok { + if varExpr, exists := exprToVar[scalarArg]; exists { + newArg = varExpr + } else { + newArg = pb.add(scalarArg) + exprToVar[scalarArg] = newArg + } + } else { + newArg = nil + } + } + + var newAgg opt.ScalarExpr + if newArg != nil && agg.ChildCount() > 0 { + var replace ReplaceFunc + replace = func(e opt.Expr) opt.Expr { + if e == agg.Child(0) { + return newArg + } + return c.f.Replace(e, replace) + } + newAgg = replace(agg).(opt.ScalarExpr) + } else { + newAgg = agg + } + + newAggs[i] = c.f.ConstructAggregationsItem(newAgg, aggregations[i].Col) + } + + newProject := pb.buildProject(input) + + var result memo.RelExpr + var disabledRules intsets.Fast + disabledRules.Add(int(opt.AggregateProjectMerge)) + c.f.DisableOptimizationRulesTemporarily(disabledRules, func() { + result = c.f.ConstructGroupBy(newProject, newAggs, groupingPrivate) + }) + return result +} + +func (c *CustomFuncs) IsRedundantSemiJoin( + left memo.RelExpr, + right memo.RelExpr, + filters memo.FiltersExpr, +) bool { + currentLeft := left + for { + if proj, ok := currentLeft.(*memo.ProjectExpr); ok { + currentLeft = proj.Input + continue + } + if sel, ok := currentLeft.(*memo.SelectExpr); ok { + currentLeft = sel.Input + continue + } + break + } + + semi, ok := currentLeft.(*memo.SemiJoinExpr) + if !ok { + return false + } + + peel := func(e memo.RelExpr) memo.RelExpr { + current := e + for { + if proj, ok := current.(*memo.ProjectExpr); ok { + current = proj.Input + continue + } + if sel, ok := current.(*memo.SelectExpr); ok { + current = sel.Input + continue + } + return current + } + } + + baseRight := peel(right) + baseSemiRight := peel(semi.Right) + + if baseRight == baseSemiRight { + return 
true + } + + scanRight, okRight := baseRight.(*memo.ScanExpr) + scanSemiRight, okSemiRight := baseSemiRight.(*memo.ScanExpr) + + if okRight && okSemiRight { + return scanRight.Table == scanSemiRight.Table + } + + return false +} + +func (c *CustomFuncs) IsVariable(scalar opt.ScalarExpr) bool { + return scalar.Op() == opt.VariableOp +} + +func (c *CustomFuncs) AggArg(agg opt.ScalarExpr) opt.ScalarExpr { + if agg.ChildCount() > 0 { + if arg, ok := agg.Child(0).(opt.ScalarExpr); ok { + return arg + } + } + return nil +} + +func (c *CustomFuncs) ReplaceAggArg(agg, newArg opt.ScalarExpr) opt.ScalarExpr { + var replace ReplaceFunc + replace = func(e opt.Expr) opt.Expr { + if e == agg.Child(0) { + return newArg + } + return c.f.Replace(e, replace) + } + return replace(agg).(opt.ScalarExpr) +} + +func (c *CustomFuncs) IsSemiJoin(input memo.RelExpr) bool { + switch input.Op() { + case opt.SemiJoinOp, opt.SemiJoinApplyOp: + return true + default: + return false + } +} + +func (c *CustomFuncs) BindFiltersToProjections( + projections memo.ProjectionsExpr, passthrough opt.ColSet, filters memo.FiltersExpr, +) memo.FiltersExpr { + var colMap opt.ColMap + for col, ok := passthrough.Next(0); ok; col, ok = passthrough.Next(col + 1) { + colMap.Set(int(col), int(col)) + } + for i := range projections { + from := projections[i].Element.(*memo.VariableExpr).Col + to := projections[i].Col + colMap.Set(int(from), int(to)) + } + newFilters := make(memo.FiltersExpr, len(filters)) + for i := range filters { + newCondition := c.f.RemapCols(filters[i].Condition, colMap) + newFilters[i] = c.f.ConstructFiltersItem(newCondition) + } + return newFilters +} + + +func (c *CustomFuncs) AllFiltersCanMapOnSetOp(filters memo.FiltersExpr) bool { + for i := range filters { + if !c.CanMapOnSetOp(&filters[i]) { + return false + } + } + return true +} + +func (c *CustomFuncs) MapSetOpFiltersLeft( + filters memo.FiltersExpr, set *memo.SetPrivate, +) memo.FiltersExpr { + newFilters := make(memo.FiltersExpr, 
len(filters)) + for i := range filters { + newCondition := c.MapSetOpFilterLeft(&filters[i], set) + newFilters[i] = c.f.ConstructFiltersItem(newCondition) + } + return newFilters +} + +func (c *CustomFuncs) MapSetOpFiltersRight( + filters memo.FiltersExpr, set *memo.SetPrivate, +) memo.FiltersExpr { + newFilters := make(memo.FiltersExpr, len(filters)) + for i := range filters { + newCondition := c.MapSetOpFilterRight(&filters[i], set) + newFilters[i] = c.f.ConstructFiltersItem(newCondition) + } + return newFilters +} + +func (c *CustomFuncs) MakeUnionPrivateForExcept(pInner, pOuter *memo.SetPrivate) *memo.SetPrivate { + if len(pInner.RightCols) != len(pOuter.RightCols) { + panic(errors.AssertionFailedf("invalid SetPrivate shapes for Except-minus merge: inner.RightCols and outer.RightCols must have same length")) + } + leftCols := make(opt.ColList, len(pInner.RightCols)) + copy(leftCols, pInner.RightCols) + + rightCols := make(opt.ColList, len(pOuter.RightCols)) + copy(rightCols, pOuter.RightCols) + + outCols := make(opt.ColList, len(pInner.RightCols)) + copy(outCols, pInner.RightCols) + + return &memo.SetPrivate{ + LeftCols: leftCols, + RightCols: rightCols, + OutCols: outCols, + } +} + +func (c *CustomFuncs) ConstructMinusMergeResult( + leftLeft memo.RelExpr, + leftRight memo.RelExpr, + right memo.RelExpr, + innerPrivate *memo.SetPrivate, + outerPrivate *memo.SetPrivate, +) memo.RelExpr { + md := c.mem.Metadata() + + outCols := make(opt.ColList, len(innerPrivate.RightCols)) + for i, col := range innerPrivate.RightCols { + outCols[i] = md.AddColumn("", md.ColumnMeta(col).Type) + } + + unionPrivate := &memo.SetPrivate{ + LeftCols: innerPrivate.RightCols, + RightCols: outerPrivate.RightCols, + OutCols: outCols, + } + + union := c.f.ConstructUnion(leftRight, right, unionPrivate) + + unionOutputCols := c.OutputCols(union).ToList() + + if len(unionOutputCols) != len(outerPrivate.RightCols) { + panic(errors.AssertionFailedf("Union output column count mismatch: got %d, 
expected %d", + len(unionOutputCols), len(outerPrivate.RightCols))) + } + + needsProject := false + for i := range unionOutputCols { + if unionOutputCols[i] != outerPrivate.RightCols[i] { + needsProject = true + break + } + } + + var unionForExcept memo.RelExpr + if needsProject { + projections := make(memo.ProjectionsExpr, len(unionOutputCols)) + for i := range unionOutputCols { + projections[i] = c.f.ConstructProjectionsItem( + c.f.ConstructVariable(unionOutputCols[i]), + outerPrivate.RightCols[i], + ) + } + unionForExcept = c.f.ConstructProject(union, projections, opt.ColSet{}) + } else { + unionForExcept = union + } + + exceptPrivate := c.MakeSetPrivate( + innerPrivate.LeftCols, + outerPrivate.RightCols, + outerPrivate.OutCols, + ) + + return c.f.ConstructExcept(leftLeft, unionForExcept, exceptPrivate) +} + +func (c *CustomFuncs) MakeSetPrivate( + leftCols, rightCols, outCols opt.ColList, +) *memo.SetPrivate { + if len(leftCols) != len(rightCols) || len(leftCols) != len(outCols) { + panic(errors.AssertionFailedf( + "invalid SetPrivate: leftCols, rightCols, and outCols must have same length", + )) + } + + leftColsCopy := make(opt.ColList, len(leftCols)) + copy(leftColsCopy, leftCols) + + rightColsCopy := make(opt.ColList, len(rightCols)) + copy(rightColsCopy, rightCols) + + outColsCopy := make(opt.ColList, len(outCols)) + copy(outColsCopy, outCols) + + return &memo.SetPrivate{ + LeftCols: leftColsCopy, + RightCols: rightColsCopy, + OutCols: outColsCopy, + } +} + diff --git a/.envrc b/src/main/java/org/qed/Backends/Datafusion/.envrc similarity index 100% rename from .envrc rename to src/main/java/org/qed/Backends/Datafusion/.envrc diff --git a/Cargo.lock b/src/main/java/org/qed/Backends/Datafusion/Cargo.lock similarity index 100% rename from Cargo.lock rename to src/main/java/org/qed/Backends/Datafusion/Cargo.lock diff --git a/Cargo.toml b/src/main/java/org/qed/Backends/Datafusion/Cargo.toml similarity index 100% rename from Cargo.toml rename to 
src/main/java/org/qed/Backends/Datafusion/Cargo.toml diff --git a/src/main/java/org/qed/Backends/Datafusion/README.md b/src/main/java/org/qed/Backends/Datafusion/README.md new file mode 100644 index 0000000..0462472 --- /dev/null +++ b/src/main/java/org/qed/Backends/Datafusion/README.md @@ -0,0 +1,294 @@ +# RuleScript + +A Rust DSL for building database query rewrite rules with uninterpreted symbols. RuleScript provides a minimal, pragmatic API that wraps DataFusion's native query planning while enabling rule verification and code generation. + +## What It Does + +RuleScript lets you express query optimizer rewrite rules using abstract patterns with uninterpreted symbols: + +```rust +// Pattern: source.filter(P).filter(Q) → source.filter(P AND Q) +crate::rule! { + FilterMergeRule { + schemas: { + source: (col: T), + }, + functions: { + P(T) -> Bool, + Q(T) -> Bool, + }, + from: { + let inner = crate::filter!(source, P(col)); + crate::filter!(inner, Q(col)) + }, + to: crate::filter!(source, P(col) && Q(col)), + } +} +``` + +The `P` and `Q` are uninterpreted predicates - they can represent ANY boolean expression. This means one rule definition covers infinite concrete cases. 
+ +## Current State + +### Implemented Rules (22 total) + +**Filter Rules:** +- FilterMergeRule - Merge consecutive filters +- FilterProjectTransposeRule - Push filter below projection +- FilterAggregateTransposeRule - Push filter predicates on GROUP BY columns below aggregate +- FilterIntoJoinRule - Merge filter into join condition +- FilterReduceTrueRule - Remove filter with true predicate +- FilterReduceFalseRule - Replace filter with false predicate with empty relation + +**Project Rules:** +- ProjectMergeRule - Merge consecutive projections +- ProjectRemoveRule - Remove identity projections + +**Join Rules:** +- JoinCommuteRule - Swap join inputs +- JoinLeftConditionPushRule - Push left-table predicates down as filter on left input +- JoinRightConditionPushRule - Push right-table predicates down as filter on right input +- JoinExtractFilterRule - Extract join condition as filter above join +- JoinLeftProjectTransposeRule - Pull projection from left join input up +- JoinRightProjectTransposeRule - Pull projection from right join input up +- JoinAssociateRule - Restructure nested joins using associativity + +**Semi-Join Rules:** +- LeftSemiJoinFilterTransposeRule - Pull filter above left semi-join +- RightSemiJoinFilterTransposeRule - Pull filter above right semi-join + +**Prune Empty Rules:** +- PruneEmptyFilterRule - Remove filter over empty relation +- PruneEmptyProjectRule - Remove projection over empty relation +- PruneEmptyUnionLeftRule - Simplify union with empty left input +- PruneEmptyUnionRightRule - Simplify union with empty right input +- PruneEmptyUnionBothRule - Simplify union with both inputs empty + +All rules have comprehensive tests (79 unit tests + 21 doc tests, all passing). + +See `src/rule/impls/README.md` for detailed rule documentation. 
+ +### Core Features + +- **Pattern Matching**: Full support for Filter, Project, Join, and user-defined operators +- **Predicate Decomposition**: Automatic splitting of conjunctive predicates based on column dependencies +- **Function Composition**: Support for nested function applications (e.g., `f(g(x))`) +- **Alias Handling**: Transparent matching through alias wrappers +- **Column Abstraction**: Smart column pattern matching that works with field partitions +- **User-Defined Operators**: Extensibility for custom logical operators with pattern matching and QED verification support +- **QED Export**: Serialization to QED format for rule verification, including EXISTS subqueries with outer column references +- **DataFusion Integration**: RuleWrapper adapter for seamless optimizer integration + +### Architecture + +``` +src/ + ast/ + opaque.rs - Abstract types, fields, schemas + relational.rs - Logical plan patterns (Source, Filter, Project, Join) + pattern.rs - Pattern functions (ScalarPattern, AggregatePattern) + extension.rs - User-defined operator support (UserDefinedLogicalOperator trait) + source.rs - Source node implementation + matcher/ + mod.rs - PatternMatcher trait and error types + default.rs - DefaultMatcher with full pattern matching logic + rule/ + mod.rs - Rule traits (RewriteRule, ApplicableRule) + test.rs - Test utilities (table helpers) + impls/ - Concrete rule implementations + verifier/ + mod.rs - Verifier trait for rule verification + qed.rs - QED format serialization with subquery support + lib.rs - Public API exports +examples/ + optimizer_repl/ - Interactive demo with all rules + user_defined_left_semi_join.rs - Example of user-defined operator with EXISTS semantics +``` + +## Quick Start + +### Define a Rule + +```rust +crate::rule! 
{ + MyRule { + schemas: { + source: (x: T), + }, + functions: { + P(T) -> Bool, + }, + from: crate::filter!(source, P(x)), + to: source, // Remove the filter + } +} +``` + +### Apply a Rule + +```rust +use rulescript::rule::{ApplicableRule, impls::FilterMergeRule}; + +let rule = FilterMergeRule; +let optimized_plan = rule.try_apply(&concrete_plan)?; +``` + +### Integrate with DataFusion + +```rust +use rulescript::rule::RuleWrapper; +use datafusion::optimizer::Optimizer; + +let optimizer_rule = RuleWrapper::new(FilterMergeRule); +optimizer.add_rule(Arc::new(optimizer_rule)); +``` + +## Run Interactive Demo + +```bash +cargo run --example optimizer_repl +``` + +Interactive demonstration with optimization rules: +- Choose which rules to apply +- See before/after query plans +- Real SQL parsing with DataFusion + +See `examples/README.md` for detailed usage. + +## Run Tests + +```bash +# All tests +cargo test + +# Specific rule +cargo test filter_merge + +# With output +cargo test -- --nocapture + +# Clippy checks +cargo clippy --all-targets +``` + +## Macro Reference + +### rule! - Define Complete Rules + +```rust +crate::rule! { + RuleName { + schemas: { + input_name: (field: Type), + other_input: (x: T1, y: T2), + }, + functions: { + FuncName(InputType) -> OutputType, + Predicate(T1, T2) -> Bool, + }, + from: { /* pattern to match */ }, + to: { /* replacement pattern */ }, + } +} +``` + +### Plan Construction Macros + +```rust +// Filter +crate::filter!(source, predicate) + +// Project +crate::project!(source, [expr1, expr2]) +crate::project!(source, [expr as alias]) + +// Join +crate::join!(left, right, Inner, condition) +crate::join!(left, right, Left, condition) +``` + +## Theoretical Foundation + +Based on the paper "RuleScript: A DSL for Query Optimizer Rules" which addresses the challenge of correctly implementing hundreds of rewrite rules in modern optimizers. 
+ +**Key Concepts:** +- **Uninterpreted Symbols**: Abstract types/functions represent families of concrete queries +- **Pattern Matching**: Declarative patterns with automatic instantiation +- **Verification**: Rules can be verified via QED solver (future integration) + +**Why This Approach:** + +Modern query optimizers suffer from: +1. Error-prone manual implementation (100-200+ rules per optimizer) +2. Correctness that is difficult to verify +3. Redundant code across similar rules + +RuleScript addresses these problems with: +1. One rule definition → many concrete applications +2. The possibility of automated verification +3. Declarative patterns that reduce implementation complexity + +## Key Design Decisions + +**Pattern Matching:** +- Column patterns match based on field partitions (one pattern can match multiple columns) +- Predicates decompose automatically based on column dependencies +- Function composition works through context mapping + +**Implementation:** +- Minimal abstraction over DataFusion's native types +- No async/tokio in tests (fast, simple tests) +- Smart defaults (all types map to Binary for uniformity) +- Functions are UDFs that error on execution (pattern-only) + +**Rule Application:** +- DefaultMatcher manages three binding types: fields, functions, sources +- Context-preserving instantiation (bindings stay in their context) +- Recursive plan transformation with captured bindings + +## External Resources + +The project references two external directories not tracked in git: + +- `parser/` - Java implementation with QED-verified rules +- `calcite/` - Apache Calcite source for rule reference + +See `PARSER_AND_CALCITE_NOTES.txt` for details on these directories.
+ +## Future Work + +**Near-term:** +- More complex join rules (with 4-predicate decomposition) +- Rule families with meta-variables +- Additional user-defined operator examples + +**Long-term:** +- SMT solver integration for automated verification +- Code generation adapters for different engines +- Performance optimizations for pattern matching + +## Known Limitations + +**Current Implementation:** +- Join rules only support INNER joins (OUTER joins require IS NOT NULL predicates) +- Some complex expression types not yet handled (SIMILAR TO with escape) +- No optimization for pattern matching efficiency + +**Design Constraints:** +- All abstract symbols must bind to at least one match (no optional predicates) +- Single binding per symbol per rule application +- Strict column validation (all references must exist in context) + +## Dependencies + +- datafusion - Query planning framework +- thiserror - Error handling macros + +## Status + +Active development. Core pattern matching complete. User-defined operator support implemented. QED export working with subquery support. API stabilizing. + +**Test Status**: 79 unit tests + 21 doc tests passing + +The project emphasizes correctness and extensibility. Pattern matching, instantiation, and user-defined operators are fully implemented, and 22 working rules demonstrate that the approach works with real DataFusion plans. QED serialization enables rule verification, including complex cases such as EXISTS subqueries with outer column references.
diff --git a/examples/README.md b/src/main/java/org/qed/Backends/Datafusion/examples/README.md similarity index 100% rename from examples/README.md rename to src/main/java/org/qed/Backends/Datafusion/examples/README.md diff --git a/examples/export_rules_to_qed.rs b/src/main/java/org/qed/Backends/Datafusion/examples/export_rules_to_qed.rs similarity index 100% rename from examples/export_rules_to_qed.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/export_rules_to_qed.rs diff --git a/examples/optimizer.rs b/src/main/java/org/qed/Backends/Datafusion/examples/optimizer.rs similarity index 100% rename from examples/optimizer.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/optimizer.rs diff --git a/examples/optimizer_repl/mod.rs b/src/main/java/org/qed/Backends/Datafusion/examples/optimizer_repl/mod.rs similarity index 100% rename from examples/optimizer_repl/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/optimizer_repl/mod.rs diff --git a/examples/optimizer_repl/tables.rs b/src/main/java/org/qed/Backends/Datafusion/examples/optimizer_repl/tables.rs similarity index 100% rename from examples/optimizer_repl/tables.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/optimizer_repl/tables.rs diff --git a/examples/optimizer_repl/wrappers.rs b/src/main/java/org/qed/Backends/Datafusion/examples/optimizer_repl/wrappers.rs similarity index 100% rename from examples/optimizer_repl/wrappers.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/optimizer_repl/wrappers.rs diff --git a/examples/tpch/mod.rs b/src/main/java/org/qed/Backends/Datafusion/examples/tpch/mod.rs similarity index 100% rename from examples/tpch/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/tpch/mod.rs diff --git a/examples/tpch/queries.rs b/src/main/java/org/qed/Backends/Datafusion/examples/tpch/queries.rs similarity index 100% rename from examples/tpch/queries.rs rename to 
src/main/java/org/qed/Backends/Datafusion/examples/tpch/queries.rs diff --git a/examples/tpch/schema.rs b/src/main/java/org/qed/Backends/Datafusion/examples/tpch/schema.rs similarity index 100% rename from examples/tpch/schema.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/tpch/schema.rs diff --git a/examples/tpch_optimize.rs b/src/main/java/org/qed/Backends/Datafusion/examples/tpch_optimize.rs similarity index 100% rename from examples/tpch_optimize.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/tpch_optimize.rs diff --git a/examples/user_defined_left_semi_join.rs b/src/main/java/org/qed/Backends/Datafusion/examples/user_defined_left_semi_join.rs similarity index 100% rename from examples/user_defined_left_semi_join.rs rename to src/main/java/org/qed/Backends/Datafusion/examples/user_defined_left_semi_join.rs diff --git a/src/main/java/org/qed/Backends/Datafusion/flake.lock b/src/main/java/org/qed/Backends/Datafusion/flake.lock new file mode 100644 index 0000000..fbbf408 --- /dev/null +++ b/src/main/java/org/qed/Backends/Datafusion/flake.lock @@ -0,0 +1,348 @@ +{ + "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "devenv" + ], + "flake-compat": [ + "devenv", + "flake-compat" + ], + "git-hooks": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760971495, + "narHash": "sha256-IwnNtbNVrlZIHh7h4Wz6VP0Furxg9Hh0ycighvL5cZc=", + "owner": "cachix", + "repo": "cachix", + "rev": "c5bfd933d1033672f51a863c47303fc0e093c2d2", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "devenv": { + "inputs": { + "cachix": "cachix", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "git-hooks": "git-hooks", + "nix": "nix", + "nixd": "nixd", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1767733209, + "narHash": "sha256-V1YN5JM1+/+MaiBH5puIjkjPssV8QNyFRT8EmCTurDY=", + 
"owner": "cachix", + "repo": "devenv", + "rev": "32a795ac142f4578aa5f6ecc8eafb79d253d99ae", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "flake": false, + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760948891, + "narHash": "sha256-TmWcdiUUaWk8J4lpjzu4gCGxWY6/Ok7mOK4fIFfBuU4=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "864599284fc7c0ba6357ed89ed5e2cd5040f0c04", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-root": { + "locked": { + "lastModified": 1723604017, + "narHash": "sha256-rBtQ8gg+Dn4Sx/s+pvjdq3CB2wQNzx9XGFq/JVGCB6k=", + "owner": "srid", + "repo": "flake-root", + "rev": "b759a56851e10cb13f6b8e5698af7b59c44be26e", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "flake-root", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + 
"original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760663237, + "narHash": "sha256-BflA6U4AM1bzuRMR8QqzPXqh8sWVCNDzOdsxXEguJIc=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nix": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-parts": [ + "devenv", + "flake-parts" + ], + "git-hooks-nix": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-23-11": [ + "devenv" + ], + "nixpkgs-regression": [ + "devenv" + ] + }, + "locked": { + "lastModified": 1766922625, + "narHash": "sha256-O0wExzdYqSNqbPYCQhUWeoKlDa7q6wxhuWiHolxqdl8=", + "owner": "cachix", + "repo": "nix", + "rev": "c62c4bdb6673871ae5cdc51c498df6292d5169aa", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "devenv-2.32", + "repo": "nix", + "type": "github" + } + }, + "nixd": { + "inputs": { + "flake-parts": [ + "devenv", + "flake-parts" + ], + "flake-root": "flake-root", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "treefmt-nix": "treefmt-nix" + }, + "locked": { + "lastModified": 1763964548, + "narHash": "sha256-JTRoaEWvPsVIMFJWeS4G2isPo15wqXY/otsiHPN0zww=", + "owner": "nix-community", + "repo": 
"nixd", + "rev": "d4bf15e56540422e2acc7bc26b20b0a0934e3f5e", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixd", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1767640445, + "narHash": "sha256-UWYqmD7JFBEDBHWYcqE6s6c77pWdcU/i+bwD6XxMb8A=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "9f0c42f8bc7151b8e7e5840fb3bd454ad850d8c5", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-python": { + "inputs": { + "flake-compat": "flake-compat_2", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1765052656, + "narHash": "sha256-DrMjrjxMttbGDoVxr/xke0ihd5GVd6fyUVsjuepEsCc=", + "owner": "cachix", + "repo": "nixpkgs-python", + "rev": "04b27dbad2e004cb237db202f21154eea3c4f89f", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "nixpkgs-python", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "nixpkgs-python": "nixpkgs-python" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "devenv", + "nixd", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734704479, + "narHash": "sha256-MMi74+WckoyEWBRcg/oaGRvXC9BVVxDZNRMpL+72wBI=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "65712f5af67234dad91a5a4baee986a8b62dbf8f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/src/main/java/org/qed/Backends/Datafusion/flake.nix 
b/src/main/java/org/qed/Backends/Datafusion/flake.nix new file mode 100644 index 0000000..1a1c5b7 --- /dev/null +++ b/src/main/java/org/qed/Backends/Datafusion/flake.nix @@ -0,0 +1,50 @@ +{ + inputs = { + devenv = { + inputs.nixpkgs.follows = "nixpkgs"; + url = "github:cachix/devenv"; + }; + flake-utils.url = "github:numtide/flake-utils"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + nixpkgs-python = { + inputs.nixpkgs.follows = "nixpkgs"; + url = "github:cachix/nixpkgs-python"; + }; + }; + + outputs = inputs @ { + self, + devenv, + flake-utils, + nixpkgs, + ... + }: + flake-utils.lib.eachDefaultSystem (system: let + pkgs = import nixpkgs { + inherit system; + config.allowUnfree = true; + }; + in { + packages = { + devenv-up = self.devShells.${system}.default.config.procfileScript; + devenv-test = self.devShells.${system}.default.config.test; + }; + + devShells.default = devenv.lib.mkShell { + inherit inputs pkgs; + modules = [ + { + git-hooks.hooks.alejandra.enable = true; + languages = { + nix.enable = true; + rust.enable = true; + }; + packages = with pkgs; [ + cvc5 + opencode + ]; + } + ]; + }; + }); +} diff --git a/src/ast/empty.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/empty.rs similarity index 100% rename from src/ast/empty.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/empty.rs diff --git a/src/ast/extension.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/extension.rs similarity index 100% rename from src/ast/extension.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/extension.rs diff --git a/src/ast/mod.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/mod.rs similarity index 100% rename from src/ast/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/mod.rs diff --git a/src/ast/opaque.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/opaque.rs similarity index 100% rename from src/ast/opaque.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/opaque.rs diff 
--git a/src/ast/pattern.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/pattern.rs similarity index 100% rename from src/ast/pattern.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/pattern.rs diff --git a/src/ast/relational.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/relational.rs similarity index 100% rename from src/ast/relational.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/relational.rs diff --git a/src/ast/source.rs b/src/main/java/org/qed/Backends/Datafusion/src/ast/source.rs similarity index 100% rename from src/ast/source.rs rename to src/main/java/org/qed/Backends/Datafusion/src/ast/source.rs diff --git a/src/lib.rs b/src/main/java/org/qed/Backends/Datafusion/src/lib.rs similarity index 100% rename from src/lib.rs rename to src/main/java/org/qed/Backends/Datafusion/src/lib.rs diff --git a/src/matcher/default.rs b/src/main/java/org/qed/Backends/Datafusion/src/matcher/default.rs similarity index 100% rename from src/matcher/default.rs rename to src/main/java/org/qed/Backends/Datafusion/src/matcher/default.rs diff --git a/src/matcher/mod.rs b/src/main/java/org/qed/Backends/Datafusion/src/matcher/mod.rs similarity index 100% rename from src/matcher/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/src/matcher/mod.rs diff --git a/src/rule/impls/README.md b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/README.md similarity index 100% rename from src/rule/impls/README.md rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/README.md diff --git a/src/rule/impls/filter_aggregate_transpose.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_aggregate_transpose.rs similarity index 100% rename from src/rule/impls/filter_aggregate_transpose.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_aggregate_transpose.rs diff --git a/src/rule/impls/filter_into_join.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_into_join.rs 
similarity index 100% rename from src/rule/impls/filter_into_join.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_into_join.rs diff --git a/src/rule/impls/filter_merge.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_merge.rs similarity index 100% rename from src/rule/impls/filter_merge.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_merge.rs diff --git a/src/rule/impls/filter_project_transpose.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_project_transpose.rs similarity index 100% rename from src/rule/impls/filter_project_transpose.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_project_transpose.rs diff --git a/src/rule/impls/filter_reduce_false.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_reduce_false.rs similarity index 100% rename from src/rule/impls/filter_reduce_false.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_reduce_false.rs diff --git a/src/rule/impls/filter_reduce_true.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_reduce_true.rs similarity index 100% rename from src/rule/impls/filter_reduce_true.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/filter_reduce_true.rs diff --git a/src/rule/impls/join_associate.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_associate.rs similarity index 100% rename from src/rule/impls/join_associate.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_associate.rs diff --git a/src/rule/impls/join_commute.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_commute.rs similarity index 100% rename from src/rule/impls/join_commute.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_commute.rs diff --git a/src/rule/impls/join_condition_push.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_condition_push.rs 
similarity index 100% rename from src/rule/impls/join_condition_push.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_condition_push.rs diff --git a/src/rule/impls/join_extract_filter.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_extract_filter.rs similarity index 100% rename from src/rule/impls/join_extract_filter.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_extract_filter.rs diff --git a/src/rule/impls/join_project_transpose.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_project_transpose.rs similarity index 100% rename from src/rule/impls/join_project_transpose.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/join_project_transpose.rs diff --git a/src/rule/impls/mod.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/mod.rs similarity index 100% rename from src/rule/impls/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/mod.rs diff --git a/src/rule/impls/project_merge.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/project_merge.rs similarity index 100% rename from src/rule/impls/project_merge.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/project_merge.rs diff --git a/src/rule/impls/project_remove.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/project_remove.rs similarity index 100% rename from src/rule/impls/project_remove.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/project_remove.rs diff --git a/src/rule/impls/prune_empty_filter.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/prune_empty_filter.rs similarity index 100% rename from src/rule/impls/prune_empty_filter.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/prune_empty_filter.rs diff --git a/src/rule/impls/prune_empty_project.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/prune_empty_project.rs similarity index 100% rename from 
src/rule/impls/prune_empty_project.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/prune_empty_project.rs diff --git a/src/rule/impls/prune_empty_union.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/prune_empty_union.rs similarity index 100% rename from src/rule/impls/prune_empty_union.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/prune_empty_union.rs diff --git a/src/rule/impls/semi_join_filter_transpose.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/impls/semi_join_filter_transpose.rs similarity index 100% rename from src/rule/impls/semi_join_filter_transpose.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/impls/semi_join_filter_transpose.rs diff --git a/src/rule/mod.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/mod.rs similarity index 100% rename from src/rule/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/mod.rs diff --git a/src/rule/test.rs b/src/main/java/org/qed/Backends/Datafusion/src/rule/test.rs similarity index 100% rename from src/rule/test.rs rename to src/main/java/org/qed/Backends/Datafusion/src/rule/test.rs diff --git a/src/verifier/mod.rs b/src/main/java/org/qed/Backends/Datafusion/src/verifier/mod.rs similarity index 100% rename from src/verifier/mod.rs rename to src/main/java/org/qed/Backends/Datafusion/src/verifier/mod.rs diff --git a/src/verifier/qed.rs b/src/main/java/org/qed/Backends/Datafusion/src/verifier/qed.rs similarity index 100% rename from src/verifier/qed.rs rename to src/main/java/org/qed/Backends/Datafusion/src/verifier/qed.rs diff --git a/src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge1.sql b/src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge1.sql new file mode 100644 index 0000000..3579e90 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge1.sql @@ -0,0 +1,5 @@ +INSERT INTO query_rewrite.rewrite_rules + (pattern, replacement) VALUES( + 'SELECT * FROM (SELECT * FROM 
testdb.users WHERE id = ?) AS t0 WHERE status = ?', + 'SELECT * FROM testdb.users WHERE id = ? AND status = ?' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge2.sql b/src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge2.sql new file mode 100644 index 0000000..a48e50c --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Generated/FilterMerge2.sql @@ -0,0 +1,5 @@ +INSERT INTO query_rewrite.rewrite_rules + (pattern, replacement) VALUES( + 'SELECT * FROM (SELECT * FROM testdb.users WHERE status = ?) AS t0 WHERE id = ?', + 'SELECT * FROM testdb.users WHERE status = ? AND id = ?' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute1.sql b/src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute1.sql new file mode 100644 index 0000000..67a32bc --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute1.sql @@ -0,0 +1,5 @@ +INSERT INTO query_rewrite.rewrite_rules + (pattern, replacement) VALUES( + 'SELECT * FROM (SELECT * FROM testdb.users) AS t0 INNER JOIN (SELECT * FROM testdb.users) AS t1 ON t0.id = t1.id', + 'SELECT * FROM (SELECT * FROM testdb.users) AS t1 INNER JOIN (SELECT * FROM testdb.users) AS t0 ON t1.id = t0.id' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute2.sql b/src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute2.sql new file mode 100644 index 0000000..3805762 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Generated/JoinCommute2.sql @@ -0,0 +1,5 @@ +INSERT INTO query_rewrite.rewrite_rules + (pattern, replacement) VALUES( + 'SELECT * FROM (SELECT * FROM testdb.users) AS t0 INNER JOIN (SELECT * FROM testdb.users) AS t1 ON t0.status = t1.status', + 'SELECT * FROM (SELECT * FROM testdb.users) AS t1 INNER JOIN (SELECT * FROM testdb.users) AS t0 ON t1.status = t0.status' +); \ No newline at end of file diff --git 
a/src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge1.sql b/src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge1.sql new file mode 100644 index 0000000..4544506 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge1.sql @@ -0,0 +1,5 @@ +INSERT INTO query_rewrite.rewrite_rules + (pattern, replacement) VALUES( + 'SELECT id, status FROM (SELECT id, status FROM testdb.users) AS t0', + 'SELECT id, status FROM testdb.users' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge2.sql b/src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge2.sql new file mode 100644 index 0000000..df80c53 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Generated/ProjectMerge2.sql @@ -0,0 +1,5 @@ +INSERT INTO query_rewrite.rewrite_rules + (pattern, replacement) VALUES( + 'SELECT status, id FROM (SELECT status, id FROM testdb.users) AS t0', + 'SELECT status, id FROM testdb.users' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/MySQLGenerator.java b/src/main/java/org/qed/Backends/MySQL/MySQLGenerator.java new file mode 100644 index 0000000..de800f4 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/MySQLGenerator.java @@ -0,0 +1,155 @@ +package org.qed.Backends.MySQL; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRuleInstances.JoinCommute; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +public class MySQLGenerator { + + private int subqueryCounter = 0; + private final String tableName; + private final List columnNames; + + public MySQLGenerator(String tableName, List columnNames) { + this.tableName = tableName; + this.columnNames = columnNames; + } + + private static class FlattenedSQLParts { + String fromClause = ""; + List projections = new ArrayList<>(); + List conditions = new ArrayList<>(); + } + + public String translate(String name, RelRN before, RelRN after) { + 
String beforeSQL; + String afterSQL; + + if (name.equals("JoinCommute")) { + subqueryCounter = 0; + beforeSQL = transformNested(before, true, false, new AtomicInteger(0)); + subqueryCounter = 0; + afterSQL = transformNested(before, true, true, new AtomicInteger(0)); + } else { + subqueryCounter = 0; + beforeSQL = transformNested(before, true, false, new AtomicInteger(0)); + afterSQL = transformFlatten(after); + } + + return "INSERT INTO query_rewrite.rewrite_rules\n" + + " (pattern, replacement) VALUES(\n" + + " '" + beforeSQL + "',\n" + + " '" + afterSQL + "'\n" + + ");"; + } + + private String transformNested(RelRN node, boolean isRoot, boolean swapJoinSides, AtomicInteger filterIndex) { + if (node instanceof RelRN.Scan) { + return "SELECT * FROM " + tableName; + } else if (node instanceof RelRN.Project project) { + String cols = String.join(", ", columnNames); + if (project.source() instanceof RelRN.Scan) { + return "SELECT " + cols + " FROM " + tableName; + } + String innerSQL = transformNested(project.source(), false, swapJoinSides, filterIndex); + String alias = "t" + (subqueryCounter++); + return "SELECT " + cols + " FROM (" + innerSQL + ") AS " + alias; + } else if (node instanceof RelRN.Filter filter) { + String innerSQL = transformNested(filter.source(), false, swapJoinSides, filterIndex); + int currentIndex = filterIndex.getAndIncrement(); + String condition = (currentIndex < columnNames.size()) + ? columnNames.get(currentIndex) + " = ?" + : columnNames.get(0) + " = ?"; + + if (isRoot) { + return innerSQL + " WHERE " + condition; + } else { + String alias = "t" + (subqueryCounter++); + return "SELECT * FROM (" + innerSQL + " WHERE " + condition + ") AS " + alias; + } + } else if (node instanceof RelRN.Join join) { + String leftAlias = "t0"; + String rightAlias = "t1"; + + RelRN firstNode = swapJoinSides ? join.right() : join.left(); + String firstAlias = swapJoinSides ? rightAlias : leftAlias; + RelRN secondNode = swapJoinSides ? 
join.left() : join.right(); + String secondAlias = swapJoinSides ? leftAlias : rightAlias; + + String firstSQL = "(" + transformNested(firstNode, false, swapJoinSides, filterIndex) + ")"; + String secondSQL = "(" + transformNested(secondNode, false, swapJoinSides, filterIndex) + ")"; + + String joinCond = renderJoinCondition(join.cond(), leftAlias, rightAlias, swapJoinSides); + + String joinExpr = + firstSQL + " AS " + firstAlias + + " " + join.ty().semantics().name() + " JOIN " + + secondSQL + " AS " + secondAlias + + " ON " + joinCond; + + if (isRoot) { + return "SELECT * FROM " + joinExpr; + } else { + String alias = "t" + (subqueryCounter++); + return "SELECT * FROM (" + joinExpr + ") AS " + alias; + } + + } else if (node instanceof JoinCommute.ProjectionRelRN projRN) { + return transformNested(projRN.source(), isRoot, swapJoinSides, filterIndex); + } else { + throw new UnsupportedOperationException("Unsupported RelRN: " + node); + } + } + + private String renderJoinCondition(RexRN cond, String leftAlias, String rightAlias, boolean swap) { + if (cond instanceof RexRN.Pred p) { + if (p.sources().get(0) instanceof RexRN.JoinField jf) { + String colName = columnNames.get(jf.ordinal()); + String first = swap ? rightAlias : leftAlias; + String second = swap ? leftAlias : rightAlias; + return first + "." + colName + " = " + second + "." + colName; + } + } + throw new UnsupportedOperationException("Unsupported join condition: " + cond); + } + + public String transformFlatten(RelRN node) { + FlattenedSQLParts parts = new FlattenedSQLParts(); + collectFlattenedParts(node, parts); + String selectClause = parts.projections.isEmpty() ? "SELECT *" : "SELECT " + String.join(", ", parts.projections); + String whereClause = parts.conditions.isEmpty() ? 
"" : " WHERE " + String.join(" AND ", parts.conditions); + return selectClause + " FROM " + parts.fromClause + whereClause; + } + + private void collectFlattenedParts(RelRN node, FlattenedSQLParts parts) { + switch (node) { + case RelRN.Scan scan -> parts.fromClause = tableName; + case RelRN.Project project -> { + collectFlattenedParts(project.source(), parts); + parts.projections.addAll(columnNames); + } + case RelRN.Filter filter -> { + collectFlattenedParts(filter.source(), parts); + collectPredConditions(filter.cond(), parts.conditions); + } + default -> throw new UnsupportedOperationException("Unsupported RelRN for flatten: " + node); + } + } + + private void collectPredConditions(RexRN pred, List conditions) { + if (pred instanceof RexRN.Pred) { + int currentConditions = conditions.size(); + if (currentConditions < columnNames.size()) { + conditions.add(columnNames.get(currentConditions) + " = ?"); + } + } else if (pred instanceof RexRN.And and) { + for (RexRN child : and.sources()) { + collectPredConditions(child, conditions); + } + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge1Test.sql b/src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge1Test.sql new file mode 100644 index 0000000..884f26c --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge1Test.sql @@ -0,0 +1,2 @@ +SELECT * FROM (SELECT * FROM testdb.users WHERE id = 1) AS t0 +WHERE status = 'active'; \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge2Test.sql b/src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge2Test.sql new file mode 100644 index 0000000..62c05e2 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/FilterMerge2Test.sql @@ -0,0 +1,2 @@ +SELECT * FROM (SELECT * FROM testdb.users WHERE status = 'active') AS t0 +WHERE id = 1; \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute1Test.sql 
b/src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute1Test.sql new file mode 100644 index 0000000..00a2a85 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute1Test.sql @@ -0,0 +1,3 @@ +SELECT * FROM (SELECT * FROM testdb.users) AS t0 +INNER JOIN (SELECT * FROM testdb.users) AS t1 +ON t0.id = t1.id; diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute2Test.sql b/src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute2Test.sql new file mode 100644 index 0000000..c213c1d --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/JoinCommute2Test.sql @@ -0,0 +1,3 @@ +SELECT * FROM (SELECT * FROM testdb.users) AS t0 +INNER JOIN (SELECT * FROM testdb.users) AS t1 +ON t0.status = t1.status; diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/MySQLTester.java b/src/main/java/org/qed/Backends/MySQL/Tests/MySQLTester.java new file mode 100644 index 0000000..9b6b62c --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/MySQLTester.java @@ -0,0 +1,45 @@ +package org.qed.Backends.MySQL.Tests; + +import org.qed.*; +import org.qed.Backends.MySQL.MySQLGenerator; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +public class MySQLTester { + + public static String genPath = "src/main/java/org/qed/Backends/MySQL/Generated"; + + public static String tableName = "testdb.users"; + public static List columnNames = List.of("id", "status"); + + public static void main(String[] args) { + var filterRule = new org.qed.RRuleInstances.FilterMerge(); + new MySQLTester().serializeWithNumericSuffix(filterRule, genPath); + + var projectRule = new org.qed.RRuleInstances.ProjectMerge(); + new MySQLTester().serializeWithNumericSuffix(projectRule, genPath); + + var joinCommute = new org.qed.RRuleInstances.JoinCommute(); + new MySQLTester().serializeWithNumericSuffix(joinCommute, genPath); + } + + public void serializeWithNumericSuffix(RRule rule, String path) { + serialize(rule, 
path, tableName, columnNames, 1); + serialize(rule, path, tableName, List.of(columnNames.get(1), columnNames.get(0)), 2); + } + + private void serialize(RRule rule, String path, String tableName, List colNames, int fileIndex) { + var generator = new MySQLGenerator(tableName, colNames); + var codeGen = generator.translate(rule.name(), rule.before(), rule.after()); + try { + Files.createDirectories(Path.of(path)); + String fileName = rule.name() + fileIndex + ".sql"; + Files.write(Path.of(path, fileName), codeGen.getBytes()); + } catch (IOException ioe) { + System.err.println(ioe.getMessage()); + } + } +} diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge1Test.sql b/src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge1Test.sql new file mode 100644 index 0000000..4800f37 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge1Test.sql @@ -0,0 +1 @@ +SELECT id, status FROM (SELECT id, status FROM testdb.users) AS t0; \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge2Test.sql b/src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge2Test.sql new file mode 100644 index 0000000..57bcffb --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/ProjectMerge2Test.sql @@ -0,0 +1 @@ +SELECT status, id FROM (SELECT status, id FROM testdb.users) AS t0; \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/MySQL/Tests/script-mysql.py b/src/main/java/org/qed/Backends/MySQL/Tests/script-mysql.py new file mode 100644 index 0000000..136e5ce --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/Tests/script-mysql.py @@ -0,0 +1,80 @@ +import mysql.connector +from pathlib import Path + +MYSQL_USER = "root" +MYSQL_PASSWORD = "wkaiz" +MYSQL_DATABASE = "query_rewrite" + +RULE_DIR = Path("../Generated") +TEST_DIR = Path(".") + +conn = mysql.connector.connect( + host="localhost", + user=MYSQL_USER, + password=MYSQL_PASSWORD, + database=MYSQL_DATABASE +) +cursor = conn.cursor() + 
+for rule_file in RULE_DIR.glob("*.sql"): + test_file = TEST_DIR / f"{rule_file.stem}test.sql" + + if not test_file.exists(): + print(f"⚠️ No matching test file found for {rule_file.name}, skipping.") + continue + + print(f"\n=== Running rule {rule_file.name} with test {test_file.name} ===") + + with rule_file.open("r", encoding="utf-8") as f: + sql_commands = f.read() + + for cmd in sql_commands.split(";"): + cmd = cmd.strip() + if cmd: + try: + cursor.execute(cmd + ";") + except mysql.connector.Error as e: + print(f"❌ Error executing command in {rule_file.name}: {e}") + continue + + conn.commit() + print(f"{rule_file} executed successfully.") + cursor.execute(""" + DELETE FROM rewrite_rules + WHERE id < ( + SELECT max_id FROM (SELECT MAX(id) AS max_id FROM rewrite_rules) AS t + ); + """) + conn.commit() + print("Deleted all rules except the last one.") + + cursor.execute("CALL flush_rewrite_rules();") + conn.commit() + print("Flushed rewrite rules.") + + cursor.execute("SELECT * FROM rewrite_rules;") + print("Current rules in table:") + for row in cursor.fetchall(): + print(row) + + with test_file.open("r", encoding="utf-8") as f: + test_query = f.read().strip() + + try: + cursor.execute(test_query) + results = cursor.fetchall() + print("\nTest query results:") + for row in results: + print(row) + except mysql.connector.Error as e: + print(f"❌ Error running test query {test_file.name}: {e}") + continue + + cursor.execute("SHOW WARNINGS;") + warnings = cursor.fetchall() + print("\nWarnings (should indicate rewrite):") + for w in warnings: + print(w) + +cursor.close() +conn.close() diff --git a/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterMerge.sql b/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterMerge.sql new file mode 100644 index 0000000..ab08153 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterMerge.sql @@ -0,0 +1,6 @@ +INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern) +VALUES ( + 10, 
1, + '^SELECT \* FROM \(SELECT \* FROM (.*) WHERE (.*) = (.*)\) AS (.*) WHERE (.*) = (.*)', + 'SELECT * FROM \1 WHERE \2 = \3 AND \5 = \6' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceFalse.sql b/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceFalse.sql new file mode 100644 index 0000000..0ae9c6c --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceFalse.sql @@ -0,0 +1,6 @@ +INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern) +VALUES ( + 40, 1, + '^SELECT (.*) FROM (.*) WHERE FALSE', + 'SELECT \1 FROM \2 LIMIT 0' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceTrue.sql b/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceTrue.sql new file mode 100644 index 0000000..4ca4ae4 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Generated/FilterReduceTrue.sql @@ -0,0 +1,6 @@ +INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern) +VALUES ( + 50, 1, + '^SELECT (.*) FROM (.*) WHERE TRUE', + 'SELECT \1 FROM \2' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Generated/JoinCommute.sql b/src/main/java/org/qed/Backends/ProxySQL/Generated/JoinCommute.sql new file mode 100644 index 0000000..0d91e1f --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Generated/JoinCommute.sql @@ -0,0 +1,6 @@ +INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern) +VALUES ( + 30, 1, + '^SELECT \* FROM (.*) AS (.*?) INNER JOIN (.*) AS (.*?) ON (.*?)\.(.*?) 
= (.*?)\.(.*)', + 'SELECT * FROM \3 AS \4 INNER JOIN \1 AS \2 ON \7.\8 = \5.\6' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Generated/ProjectMerge.sql b/src/main/java/org/qed/Backends/ProxySQL/Generated/ProjectMerge.sql new file mode 100644 index 0000000..8af9ae2 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Generated/ProjectMerge.sql @@ -0,0 +1,6 @@ +INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern) +VALUES ( + 20, 1, + '^SELECT (.*) FROM \(SELECT (.*) FROM (.*)\) AS (.*)', + 'SELECT \1 FROM \3' +); \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/ProxySQLGenerator.java b/src/main/java/org/qed/Backends/ProxySQL/ProxySQLGenerator.java new file mode 100644 index 0000000..50f6983 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/ProxySQLGenerator.java @@ -0,0 +1,135 @@ +package org.qed.Generated; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRuleInstances.JoinCommute; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +public class ProxySQLGenerator { + + private final Map predicateToGroupIndex = new HashMap<>(); + private final AtomicInteger groupCounter = new AtomicInteger(1); + private boolean isReduceTrueRule = false; + + public String translate(int ruleId, String name, RelRN before, RelRN after) { + predicateToGroupIndex.clear(); + groupCounter.set(1); + this.isReduceTrueRule = false; + + String matchPattern = generateMatchPattern(before); + String replacePattern = generateReplacePattern(after); + + return String.format( + """ + INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern) + VALUES ( + %d, 1, + '^%s', + '%s' + );""", + ruleId, matchPattern, replacePattern + ); + } + + private String generateMatchPattern(RelRN node) { + return switch (node) { + case RelRN.Filter filter -> { + if (filter.cond() 
instanceof RexRN.True) { + this.isReduceTrueRule = true; + groupCounter.addAndGet(2); + yield "SELECT (.*) FROM (.*) WHERE TRUE"; + } + + if (filter.cond() instanceof RexRN.False) { + groupCounter.addAndGet(2); + yield "SELECT (.*) FROM (.*) WHERE FALSE"; + } + + String sourcePattern = generateMatchPattern(filter.source()); + String conditionRegex = "(.*) = (.*)"; + if (filter.source() instanceof RelRN.Scan) { + int conditionGroupStart = groupCounter.get(); + groupCounter.addAndGet(2); + predicateToGroupIndex.put(filter.cond(), conditionGroupStart); + yield sourcePattern + " WHERE " + conditionRegex; + } else { + groupCounter.getAndIncrement(); + int conditionGroupStart = groupCounter.get(); + groupCounter.addAndGet(2); + predicateToGroupIndex.put(filter.cond(), conditionGroupStart); + yield String.format("SELECT \\* FROM \\(%s\\) AS (.*) WHERE %s", sourcePattern, conditionRegex); + } + } + case RelRN.Project project -> { + if (project.source() instanceof RelRN.Project innerProject && innerProject.source() instanceof RelRN.Scan) { + yield "SELECT (.*) FROM \\(SELECT (.*) FROM (.*)\\) AS (.*)"; + } + throw new UnsupportedOperationException("This generator only supports the specific Project(Project(Scan)) pattern."); + } + case RelRN.Join join -> { + if (join.left() instanceof RelRN.Scan && join.right() instanceof RelRN.Scan) { + yield "SELECT \\* FROM (.*) AS (.*?) INNER JOIN (.*) AS (.*?) ON (.*?)\\.(.*?) 
= (.*?)\\.(.*)"; + } + throw new UnsupportedOperationException("This generator only supports simple Scan-Join-Scan patterns."); + } + case RelRN.Scan scan -> { + groupCounter.getAndIncrement(); + yield "SELECT \\* FROM (.*)"; + } + default -> throw new UnsupportedOperationException("Unsupported RelRN for match pattern: " + node.getClass().getSimpleName()); + }; + } + + private String generateReplacePattern(RelRN node) { + return switch (node) { + case RelRN.Empty empty -> { + yield "SELECT \\1 FROM \\2 LIMIT 0"; + } + case JoinCommute.ProjectionRelRN proj -> { + if (proj.source() instanceof RelRN.Join) { + yield "SELECT * FROM \\3 AS \\4 INNER JOIN \\1 AS \\2 ON \\7.\\8 = \\5.\\6"; + } + throw new UnsupportedOperationException("Unsupported 'after' pattern for JoinCommute."); + } + case RelRN.Filter filter -> { + String fromClause = generateReplacePattern(filter.source()); + String whereClause = buildWhereClause(filter.cond()); + yield String.format("%s WHERE %s", fromClause, whereClause); + } + case RelRN.Project project -> { + if (project.source() instanceof RelRN.Scan) { + yield "SELECT \\1 FROM \\3"; + } + throw new UnsupportedOperationException("Unsupported 'after' pattern for ProjectMerge."); + } + case RelRN.Scan scan -> { + if (this.isReduceTrueRule) { + yield "SELECT \\1 FROM \\2"; + } else { + yield "SELECT * FROM \\1"; + } + } + default -> throw new UnsupportedOperationException("Unsupported RelRN for replace pattern: " + node.getClass().getSimpleName()); + }; + } + + private String buildWhereClause(RexRN condition) { + return switch (condition) { + case RexRN.And andNode -> andNode.sources().stream() + .map(this::buildWhereClause) + .collect(Collectors.joining(" AND ")); + case RexRN.Pred pred -> { + Integer groupIndex = predicateToGroupIndex.get(pred); + if (groupIndex == null) { + throw new IllegalStateException("Predicate from 'after' tree not found in 'before' tree: " + pred); + } + yield String.format("\\%d = \\%d", groupIndex, groupIndex + 1); + } 
+ default -> throw new UnsupportedOperationException("Unsupported RexRN for WHERE clause: " + condition.getClass().getSimpleName()); + }; + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterMergeTest.sql b/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterMergeTest.sql new file mode 100644 index 0000000..a418bc4 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterMergeTest.sql @@ -0,0 +1 @@ +SELECT * FROM (SELECT * FROM testdb.users WHERE status = 'active') AS t0 WHERE id = 1; \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceFalseTest.sql b/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceFalseTest.sql new file mode 100644 index 0000000..d100168 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceFalseTest.sql @@ -0,0 +1 @@ +SELECT name FROM testdb.users WHERE FALSE; diff --git a/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceTrueTest.sql b/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceTrueTest.sql new file mode 100644 index 0000000..a6a709f --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Tests/FilterReduceTrueTest.sql @@ -0,0 +1 @@ +SELECT * FROM testdb.users WHERE TRUE; diff --git a/src/main/java/org/qed/Backends/ProxySQL/Tests/JoinCommuteTest.sql b/src/main/java/org/qed/Backends/ProxySQL/Tests/JoinCommuteTest.sql new file mode 100644 index 0000000..a881320 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Tests/JoinCommuteTest.sql @@ -0,0 +1 @@ +SELECT * FROM testdb.users AS u INNER JOIN testdb.orders AS o ON u.id = o.user_id; \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/Tests/ProjectMergeTest.sql b/src/main/java/org/qed/Backends/ProxySQL/Tests/ProjectMergeTest.sql new file mode 100644 index 0000000..bcefcd6 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Tests/ProjectMergeTest.sql @@ -0,0 +1 @@ +SELECT name FROM 
(SELECT name FROM testdb.users) AS t0; diff --git a/src/main/java/org/qed/Backends/ProxySQL/Tests/script-proxysql.sh b/src/main/java/org/qed/Backends/ProxySQL/Tests/script-proxysql.sh new file mode 100644 index 0000000..6ab1db8 --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/Tests/script-proxysql.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +PROXYSQL_USER="admin" +PROXYSQL_PASS="admin" +PROXYSQL_HOST="127.0.0.1" +PROXYSQL_PORT="6032" + +MYSQL_USER="root" +MYSQL_PASS="wkaiz" +MYSQL_HOST="127.0.0.1" +MYSQL_PORT="6033" + +run_proxysql() { + local sql="$1" + docker exec -i proxysql mysql -u "$PROXYSQL_USER" -p"$PROXYSQL_PASS" -h "$PROXYSQL_HOST" -P"$PROXYSQL_PORT" -e "$sql" +} + +run_mysql() { + local sql="$1" + echo -e "\n➡️ Running MySQL command:\n$sql\n" + mysql -u "$MYSQL_USER" -p"$MYSQL_PASS" -h "$MYSQL_HOST" -P"$MYSQL_PORT" -e "$sql" + echo -e "\n----------------------------------------\n" +} + +for rule_file in proxysql/*.sql; do + base_name=$(basename "$rule_file" .sql) + test_file="Tests-ProxySQL/${base_name}Test.sql" + + echo -e "\n==============================" + echo "📄 Processing rule file: $rule_file" + echo "Corresponding test file: $test_file" + echo "==============================\n" + + sql_content=$(cat <(echo "USE main;") "$rule_file") + echo -e "➡️ Loading ProxySQL rule:\n$sql_content\n" + run_proxysql "$sql_content" + + run_proxysql "LOAD MYSQL QUERY RULES TO RUNTIME;" + run_proxysql "SAVE MYSQL QUERY RULES TO DISK;" + + echo -e "\n📊 ProxySQL stats after loading rule:" + run_proxysql "SELECT * FROM stats.stats_mysql_query_rules ORDER BY hits DESC;" + echo -e "\n----------------------------------------\n" + + test_sql=$(<"$test_file") + run_mysql "$test_sql" + + sleep 2 + echo -e "\n📊 ProxySQL stats after running test:" + run_proxysql "SELECT * FROM stats.stats_mysql_query_rules ORDER BY hits DESC;" + echo -e "\n----------------------------------------\n" + + echo "🧹 Cleaning up ProxySQL rules..." 
+ run_proxysql "DELETE FROM mysql_query_rules;" + run_proxysql "LOAD MYSQL QUERY RULES TO RUNTIME;" + run_proxysql "SAVE MYSQL QUERY RULES TO DISK;" + echo -e "\n========================================\n" +done diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java new file mode 100644 index 0000000..4700fb1 --- /dev/null +++ b/src/main/java/org/qed/CodeGenerator.java @@ -0,0 +1,273 @@ +package org.qed; + +public interface CodeGenerator { + + default String unimplemented(String context, Object object) { + return "<--" + context + object.getClass().getName() + "-->"; + } + + default E unimplementedOnMatch(E env, Object object) { + System.err.println(unimplemented("Unspecified onMatch codegen: ", object)); + return env; + } + + default E unimplementedTransform(E env, Object object) { + System.err.println(unimplemented("Unspecified transform codegen: ", object)); + return env; + } + + E preMatch(String rulename); + + default E onMatch(E env, RelRN pattern) { + return switch (pattern) { + case RelRN.Scan scan -> onMatchScan(env, scan); + case RelRN.Filter filter -> onMatchFilter(env, filter); + case RelRN.Project project -> onMatchProject(env, project); + case RelRN.Join join -> onMatchJoin(env, join); + case RelRN.JoinWithSeparateConds join -> onMatchJoinWithSeparateConds(env, join); + case RelRN.Union union -> onMatchUnion(env, union); + case RelRN.Intersect intersect -> onMatchIntersect(env, intersect); + case RelRN.Minus minus -> onMatchMinus(env, minus); + case RelRN.Empty empty -> onMatchEmpty(env, empty); + case RelRN.Aggregate aggregate -> onMatchAggregate(env, aggregate); + default -> onMatchCustom(env, pattern); + }; + } + + default E onMatch(E env, RexRN pattern) { + return switch (pattern) { + case RexRN.Field field -> onMatchField(env, field); + case RexRN.JoinField joinField -> onMatchJoinField(env, joinField); + case RexRN.Proj proj -> onMatchProj(env, proj); + case RexRN.Pred pred -> onMatchPred(env, pred); + 
case RexRN.And and -> onMatchAnd(env, and); + case RexRN.Or or -> onMatchOr(env, or); + case RexRN.Not not -> onMatchNot(env, not); + case RexRN.True literal -> onMatchTrue(env, literal); + case RexRN.False literal -> onMatchFalse(env, literal); + default -> onMatchCustom(env, pattern); + }; + } + + default E postMatch(E env) { + return env; + } + + default E preTransform(E env) { + return env; + } + + default E transform(E env, RelRN target) { + return switch (target) { + case RelRN.Scan scan -> transformScan(env, scan); + case RelRN.Filter filter -> transformFilter(env, filter); + case RelRN.Project project -> transformProject(env, project); + case RelRN.Join join -> transformJoin(env, join); + case RelRN.JoinWithPushedConds join -> transformJoinWithPushedConds(env, join); + case RelRN.Union union -> transformUnion(env, union); + case RelRN.Intersect intersect -> transformIntersect(env, intersect); + case RelRN.Minus minus -> transformMinus(env, minus); + case RelRN.Empty empty -> transformEmpty(env, empty); + case RelRN.Aggregate aggregate -> transformAggregate(env, aggregate); + default -> transformCustom(env, target); + }; + } + + default E transform(E env, RexRN target) { + return switch (target) { + case RexRN.Field field -> transformField(env, field); + case RexRN.JoinField joinField -> transformJoinField(env, joinField); + case RexRN.Pred pred -> transformPred(env, pred); + case RexRN.Proj proj -> transformProj(env, proj); + case RexRN.And and -> transformAnd(env, and); + case RexRN.Or or -> transformOr(env, or); + case RexRN.Not not -> transformNot(env, not); + case RexRN.True literal -> transformTrue(env, literal); + case RexRN.False literal -> transformFalse(env, literal); + default -> transformCustom(env, target); + }; + } + + default E postTransform(E env) { + return env; + } + + default String translate(String name, E onMatch, E transform) { + return "Unspecified translation to target language"; + } + + default String generate(RRule rule) { + 
System.out.printf("Generating Rule: %s\n", rule.name()); + var onMatch = postMatch(onMatch(preMatch(rule.name()), rule.before())); + var transform = postTransform(transform(preTransform(onMatch), rule.after())); + return translate(rule.name(), onMatch, transform); + } + + default E onMatchScan(E env, RelRN.Scan scan) { + return unimplementedOnMatch(env, scan); + } + + default E onMatchFilter(E env, RelRN.Filter filter) { + return unimplementedOnMatch(env, filter); + } + + default E onMatchProject(E env, RelRN.Project project) { + return unimplementedOnMatch(env, project); + } + + default E onMatchJoin(E env, RelRN.Join join) { + return unimplementedOnMatch(env, join); + } + + default E onMatchJoinWithSeparateConds(E env, RelRN.JoinWithSeparateConds join) { + return unimplementedOnMatch(env, join); + } + + default E onMatchUnion(E env, RelRN.Union union) { + return unimplementedOnMatch(env, union); + } + + default E onMatchIntersect(E env, RelRN.Intersect intersect) { + return unimplementedOnMatch(env, intersect); + } + + default E onMatchMinus(E env, RelRN.Minus minus) { + return unimplementedOnMatch(env, minus); + } + + default E onMatchCustom(E env, RelRN custom) { + return unimplementedOnMatch(env, custom); + } + + default E onMatchField(E env, RexRN.Field field) { + return unimplementedOnMatch(env, field); + } + + default E onMatchJoinField(E env, RexRN.JoinField joinField) { + return unimplementedOnMatch(env, joinField); + } + + default E onMatchPred(E env, RexRN.Pred pred) { + return unimplementedOnMatch(env, pred); + } + + default E onMatchProj(E env, RexRN.Proj proj) { + return unimplementedOnMatch(env, proj); + } + + default E onMatchAnd(E env, RexRN.And and) { + return unimplementedOnMatch(env, and); + } + + default E onMatchOr(E env, RexRN.Or or) { + return unimplementedOnMatch(env, or); + } + + default E onMatchNot(E env, RexRN.Not not) { + return unimplementedOnMatch(env, not); + } + + default E onMatchCustom(E env, RexRN custom) { + return 
unimplementedOnMatch(env, custom); + } + + default E onMatchTrue(E env, RexRN literal) { + return unimplementedOnMatch(env, literal); + } + + default E onMatchFalse(E env, RexRN literal) { + return unimplementedOnMatch(env, literal); + } + + default E onMatchEmpty(E env, RelRN.Empty empty) { + return unimplementedOnMatch(env, empty); + } + + default E transformScan(E env, RelRN.Scan scan) { + return unimplementedTransform(env, scan); + } + + default E transformFilter(E env, RelRN.Filter filter) { + return unimplementedTransform(env, filter); + } + + default E transformProject(E env, RelRN.Project project) { + return unimplementedTransform(env, project); + } + + default E transformJoin(E env, RelRN.Join join) { + return unimplementedTransform(env, join); + } + + default E transformJoinWithPushedConds(E env, RelRN.JoinWithPushedConds join) { + return unimplementedTransform(env, join); + } + + default E transformUnion(E env, RelRN.Union union) { + return unimplementedTransform(env, union); + } + + default E transformIntersect(E env, RelRN.Intersect intersect) { + return unimplementedTransform(env, intersect); + } + + default E transformMinus(E env, RelRN.Minus minus) { + return unimplementedTransform(env, minus); + } + + default E transformCustom(E env, RelRN custom) { + return unimplementedTransform(env, custom); + } + + default E transformField(E env, RexRN.Field field) { + return unimplementedTransform(env, field); + } + + default E transformJoinField(E env, RexRN.JoinField joinField) { + return unimplementedTransform(env, joinField); + } + + default E transformProj(E env, RexRN.Proj proj) { + return unimplementedTransform(env, proj); + } + + default E transformPred(E env, RexRN.Pred pred) { + return unimplementedTransform(env, pred); + } + + default E transformAnd(E env, RexRN.And and) { + return unimplementedTransform(env, and); + } + + default E transformOr(E env, RexRN.Or or) { + return unimplementedTransform(env, or); + } + + default E transformNot(E env, 
RexRN.Not not) { + return unimplementedTransform(env, not); + } + + default E transformCustom(E env, RexRN custom) { + return unimplementedTransform(env, custom); + } + + default E transformTrue(E env, RexRN literal) { + return unimplementedTransform(env, literal); + } + + default E transformFalse(E env, RexRN literal) { + return unimplementedTransform(env, literal); + } + + default E transformEmpty(E env, RelRN.Empty empty) { + return unimplementedTransform(env, empty); + } + + default E onMatchAggregate(E env, RelRN.Aggregate aggregate) { + return unimplementedOnMatch(env, aggregate); + } + + default E transformAggregate(E env, RelRN.Aggregate aggregate) { + return unimplementedTransform(env, aggregate); + } +} diff --git a/src/main/java/org/qed/Env.java b/src/main/java/org/qed/Env.java new file mode 100644 index 0000000..3f6f1e9 --- /dev/null +++ b/src/main/java/org/qed/Env.java @@ -0,0 +1,39 @@ +package org.qed; + +import kala.collection.Seq; +import kala.collection.immutable.ImmutableMap; +import kala.collection.mutable.MutableList; +import org.apache.calcite.rel.core.CorrelationId; + +import java.util.Set; + +record Env(int base, int delta, ImmutableMap globals, MutableList tables) { + static Env empty() { + return new Env(0, 0, ImmutableMap.empty(), MutableList.create()); + } + + Env recorded(Set ids) { + return new Env(base, delta, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, base)), tables); + } + + Env advanced(int d) { + return new Env(base + delta, d, globals, tables); + } + + Env lifted(int d) { + return new Env(base + d, delta, globals, tables); + } + + int resolve(QedTable table) { + var idx = tables.indexOf(table); + if (idx == -1) { + idx = tables.size(); + tables.append(table); + } + return idx; + } + + int resolve(CorrelationId id) { + return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + } +} diff --git a/src/main/java/org/qed/JSONDeserializer.java b/src/main/java/org/qed/JSONDeserializer.java new 
file mode 100644 index 0000000..a121124 --- /dev/null +++ b/src/main/java/org/qed/JSONDeserializer.java @@ -0,0 +1,365 @@ +package org.qed; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import kala.collection.Seq; +import kala.collection.Set; +import kala.collection.immutable.ImmutableSeq; +import kala.control.Try; +import kala.function.CheckedFunction; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.type.*; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; + +import java.io.File; +import java.text.NumberFormat; +import java.util.List; +import java.util.Objects; + +public record JSONDeserializer() { + private final static ObjectMapper mapper = new ObjectMapper(); + + private static ImmutableSeq array(JsonNode node) throws Exception { + if (!node.isArray()) throw new Exception(); + return ImmutableSeq.from(node.elements()); + } + + private static ImmutableSeq array(JsonNode node, String path) throws Exception { + return array(node.required(path)); + } + + private static String string(JsonNode node) throws Exception { + if (!node.isTextual()) throw new Exception(); + return node.asText(); + } + + private static String string(JsonNode node, String path) throws Exception { + return 
string(node.required(path)); + } + + private static int integer(JsonNode node) throws Exception { + if (!node.isInt()) throw new Exception(); + return node.asInt(); + } + + private static int integer(JsonNode node, String path) throws Exception { + return integer(node.required(path)); + } + + private static boolean bool(JsonNode node) throws Exception { + if (!node.isBoolean()) throw new Exception(); + return node.asBoolean(); + } + + static SqlTypeName typeName(String name) { + name = switch (name) { + case "BOOL" -> "BOOLEAN"; + case "INT", "INT2", "INT4", "OID" -> "INTEGER"; + case "TIMESTAMPTZ" -> "TIMESTAMP"; + case "TIMETZ" -> "TIME"; + case "STRING" -> "VARCHAR"; + case "JSONB" -> "VARBINARY"; + default -> name; + }; + return Enum.valueOf(SqlTypeName.class, name); + } + + public static ImmutableSeq load(File file) throws Exception { + return new JSONDeserializer().deserialize(mapper.readTree(file)); + } + + public static void main(String[] args) throws Exception { + var refs = Seq.from(new File("RelOptRulesTest").listFiles()); + for (var file : refs) { + try { + var store = mapper.readTree(file); + new JSONDeserializer().deserialize(store); + } catch (Exception e) { + System.err.println("===> " + file.getName() + " <==="); + System.err.println(e.getMessage()); + System.err.println(); + } + } + } + + public ImmutableSeq deserialize(JsonNode node) throws Exception { + var builder = RuleBuilder.create(); + var tables = array(node, "schemas").mapChecked(schema -> { + var types = array(schema, "types").mapChecked(JSONDeserializer::string); + var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); + var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); + var fields = schema.get("fields") == null ? 
+ Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : + array(schema, "fields").mapChecked(JSONDeserializer::string); + var keys = Set.from(array(schema, "key").map( + CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); + if (types.size() != nullabilities.size()) + throw new Exception("Expecting corresponding types and nullabilities"); + var sts = types.zip(nullabilities).map(tn -> { + var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); + return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); + }); + var table = new QedTable(name, fields, sts, keys, Set.empty()); + builder.addTable(table); + return table; + }); + var rel = new Rel(builder, ImmutableSeq.empty(), tables); + return array(node, "queries").mapChecked(rel); + } + + private record Rel(RuleBuilder builder, ImmutableSeq globals, ImmutableSeq tables) + implements CheckedFunction { + Rel(RuleBuilder builder) { + this(builder, ImmutableSeq.empty(), ImmutableSeq.empty()); + } + + RexCorrelVariable corr(RelDataType type) { + return (RexCorrelVariable) builder().getRexBuilder() + .makeCorrel(type, builder().getCluster().createCorrel()); + } + + Rel lifted(RexNode corr) { + var rex = builder().getRexBuilder(); + var vars = ImmutableSeq.fill(corr.getType().getFieldCount(), i -> rex.makeFieldAccess(corr, i)); + return new Rel(builder(), globals().appendedAll(vars), tables()); + } + + Rex rex(RexCorrelVariable local) { + return new Rex(builder(), globals(), local, tables()); + } + + Rex rex() { + var empty = (RexCorrelVariable) builder().getRexBuilder() + .makeCorrel(builder().getTypeFactory().createStructType(List.of()), + builder().getCluster().createCorrel()); + return new Rex(builder(), globals(), empty, tables()); + } + + JoinRelType kind(String k) throws Exception { + return Enum.valueOf(JoinRelType.class, k); + } + + public RelNode applyChecked(JsonNode node) throws Exception { + return 
deserialize(node); + } + + public RelNode deserialize(JsonNode node) throws Exception { + var entry = node.fields().next(); + var kind = entry.getKey(); + var content = entry.getValue(); + return switch (kind) { + case "scan" -> builder().scan(tables().get(integer(content)).getName()).build(); + case "values" -> { + var et = array(content, "schema"); + var rt = new RelRecordType(StructKind.FULLY_QUALIFIED, et.mapIndexedChecked( + (i, t) -> (RelDataTypeField) new RelDataTypeFieldImpl(String.format("VALUES-%s", i), i, + RelType.fromString(string(t), true))).asJava()); + var tuples = array(content, "content").mapChecked( + v -> array(v).mapChecked(jl -> (RexLiteral) rex().deserialize(jl)).asJava()); + yield builder().values(tuples.asJava(), rt).build(); + } + case "filter" -> { + var input = deserialize(content.required("source")); + var corr = corr(input.getRowType()); + var cond = rex(corr).deserialize(content.required("condition")); + yield builder().push(input).filter(Seq.of(corr.id), cond).build(); + } + case "project" -> { + var input = deserialize(content.required("source")); + var corr = corr(input.getRowType()); + var rex = rex(corr); + var projections = array(content, "target").mapChecked(rex); + yield builder().push(input).project(projections, Seq.empty(), false, Seq.of(corr.id)).build(); + } + case "join" -> { + var left = deserialize(content.required("left")); + var right = deserialize(content.required("right")); + var corr = corr(builder().getTypeFactory().createJoinType(left.getRowType(), right.getRowType())); + var cond = rex(corr).deserialize(content.required("condition")); + yield LogicalJoin.create(left, right, ImmutableList.of(), cond, Set.of(corr.id).asJava(), + kind(string(content, "kind"))); + } + case "correlate" -> { + var left = deserialize(content.required("left")); + var corr = corr(left.getRowType()); + var right = lifted(corr).deserialize(content.required("right")); + var rex = builder().getRexBuilder(); + var required = + 
Seq.from(RelOptUtil.correlationColumns(corr.id, right)).map(i -> rex.makeInputRef(left, i)); + yield builder().push(left).push(right).correlate(kind(string(content, "kind")), corr.id, required) + .build(); + } + case "union" -> { + var inputs = array(content).mapChecked(this::deserialize); + yield builder().pushAll(inputs).union(true, inputs.size()).build(); + } + case "intersect" -> { + var inputs = array(content).mapChecked(this::deserialize); + yield builder().pushAll(inputs).intersect(false, inputs.size()).build(); + } + case "except" -> { + var inputs = array(content).mapChecked(this::deserialize); + yield builder().pushAll(inputs).minus(false, inputs.size()).build(); + } + case "distinct" -> builder().push(deserialize(content)).distinct().build(); + case "group" -> { + var input = deserialize(content.required("source")); + var rex = rex(corr(input.getRowType())); + var keys = builder().groupKey(array(content, "keys").mapChecked(rex)); + yield builder().push(input).aggregate(keys, array(content, "function").mapChecked(rex::agg)) + .build(); + } + case "sort" -> { + var input = deserialize(content.required("source")); + var collations = RelCollations.of(array(content, "collation").mapChecked(coll -> { + var c = array(coll); + var col = integer(c.get(0)); + var ord = string(c.get(2)); + return new RelFieldCollation(col, Enum.valueOf(RelFieldCollation.Direction.class, ord)); + }).asJava()); + var sorted = builder().push(input).sort(collations).build(); + if (content.get("limit") == null) yield sorted; + yield builder().push(sorted).sortLimit(rex().deserialize(content.required("offset")), + rex().deserialize(content.required("limit")), Seq.empty()).build(); + } + default -> throw new Exception(String.format("Unrecognized node:\n%s", node.toPrettyString())); + }; + } + } + + private record Rex(RuleBuilder builder, ImmutableSeq globals, RexCorrelVariable local, + ImmutableSeq tables) implements CheckedFunction { + static Seq ops = 
Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) + .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && + java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { + var mist = Try.of(() -> f.get(null)).getOrNull(); + if (mist == null) return null; + if (mist instanceof SqlOperator op) return op; + return null; + }).filter(Objects::nonNull); + + public RexNode resolve(int lvl) { + assert lvl < globals().size() + local().getType().getFieldCount(); + return lvl < globals().size() ? globals().get(lvl) : builder().getRexBuilder() + .makeInputRef(local().getType().getFieldList().get(lvl - globals().size()).getType(), + lvl - globals().size()); + } + + public Rel rel() { + var rex = builder().getRexBuilder(); + var locals = Seq.fill(local().getType().getFieldCount(), i -> rex.makeFieldAccess(local(), i)); + return new Rel(builder(), globals().appendedAll(locals), tables()); + } + + public RexNode applyChecked(JsonNode node) throws Exception { + return deserialize(node); + } + + public RelDataType type(String name) { + return builder().getTypeFactory().createSqlType(typeName(name)); + } + + SqlOperator op(String name, int arity) throws Exception { + switch (name) { + case "BOOL_AND" -> { + return SqlStdOperatorTable.AND; + } + case "MINUS" -> { + return SqlStdOperatorTable.MINUS; + } + case "UNARY MINUS" -> { + return SqlStdOperatorTable.UNARY_MINUS; + } + case "PLUS" -> { + return SqlStdOperatorTable.PLUS; + } + case "UNARY PLUS" -> { + return SqlStdOperatorTable.UNARY_PLUS; + } + case "+" -> { + if (arity == 2) { + return SqlStdOperatorTable.PLUS; + } else if (arity == 1) { + return SqlStdOperatorTable.UNARY_PLUS; + } + } + case "-" -> { + if (arity == 2) { + return SqlStdOperatorTable.MINUS; + } else if (arity == 1) { + return SqlStdOperatorTable.UNARY_MINUS; + } + } + } + var finalName = switch (name) { + case "EQ" -> "="; + case "GT" -> ">"; + case "LT" -> "<"; + case "GE" -> ">="; + case "LE" -> "<="; + case "MULT" -> "*"; + case 
"DIV" -> "/"; + case "IS", "<=>" -> "IS NOT DISTINCT FROM"; + case "IS NOT" -> "IS DISTINCT FROM"; + default -> name; + }; + var candicates = ops.filter(op -> op.getName().equals(finalName)); + if (candicates.isEmpty()) throw new Exception(String.format("Unknown operator name %s.", name)); + if (candicates.size() > 1) throw new Exception(String.format("Ambiguous operator name %s.", name)); + return candicates.first(); + } + + RelBuilder.AggCall agg(JsonNode node) throws Exception { + return builder().aggregateCall((SqlAggFunction) op(string(node, "operator"), 1), + array(node, "operand").mapChecked(this::deserialize)); + } + + public RexNode deserialize(JsonNode node) throws Exception { + var rex = builder().getRexBuilder(); + if (node.has("column")) { + return resolve(integer(node, "column")); + } else if (node.has("query")) { + var operator = string(node, "operator"); + var operands = array(node, "operand").mapChecked(this); + var query = rel().deserialize(node.required("query")); + return switch (operator.toLowerCase()) { + case "exists" -> RexSubQuery.exists(query); + case "unique" -> RexSubQuery.unique(query); + case "in" -> builder().in(query, operands); + default -> throw new Exception(String.format("Unknown subquery %s", operator)); + }; + } else { + var operator = string(node, "operator"); + var operands = array(node, "operand"); + var type = type(string(node, "type")); + if (operands.isEmpty()) { + return switch (operator.toLowerCase()) { + case "null" -> rex.makeNullLiteral(type); + case String lit -> Try.of(() -> rex.makeLiteral(Boolean.parseBoolean(lit), type)).getOrElse( + () -> Try.of(() -> rex.makeLiteral(NumberFormat.getInstance().parse(lit), type)) + .getOrElse(() -> rex.makeLiteral(lit, type))); + }; + } else { + return builder().getRexBuilder().makeCall(type, op(operator, operands.size()), + operands.mapChecked(this::deserialize).asJava()); + } + } + } + } +} + diff --git a/src/main/java/org/qed/JSONSerializer.java 
b/src/main/java/org/qed/JSONSerializer.java new file mode 100644 index 0000000..1d077f8 --- /dev/null +++ b/src/main/java/org/qed/JSONSerializer.java @@ -0,0 +1,228 @@ +package org.qed; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.*; +import kala.collection.Map; +import kala.collection.Seq; +import kala.collection.immutable.ImmutableMap; +import kala.collection.mutable.MutableList; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.*; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.*; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.Set; + +public record JSONSerializer(Env env) { + private final static ObjectMapper mapper = new ObjectMapper(); + + private static ArrayNode array(Seq objs) { + return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); + } + + private static ObjectNode object(Map fields) { + return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); + } + + private static BooleanNode bool(boolean b) { + return BooleanNode.valueOf(b); + } + + private static TextNode string(String s) { + return new TextNode(s); + } + + private static TextNode type(RelDataType type) { + if (type instanceof RelType.VarType varType) { + return new TextNode(varType.getName()); + } + return new TextNode(type.getSqlTypeName().getName()); + } + + private static IntNode integer(int i) { + return new IntNode(i); + } + + private static String qualifiedTableName(RelOptTable table) { + return Seq.from(table.getQualifiedName()).joinToString("."); + } + + public static ObjectNode serialize(Seq relNodes) { + var shuttle = new Rel(); + var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); + var queries = 
array(relNodes.map(shuttle::serialize)); + var tables = shuttle.env.tables(); + var schemas = array(tables.map(table -> { + var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); + var qedTable = table.unwrap(QedTable.class); + var fields = Seq.from(table.getRowType().getFieldList()); + return qedTable == null ? + object(Map.of("name", string(qualifiedTableName(table)), "fields", + array(fields.map(field -> string(field.getName()))), "types", + array(fields.map(field -> type(field.getType()))), "nullable", + array(fields.map(field -> bool(field.getType().isNullable()))), "key", + array((table.getKeys() != null ? Seq.from(table.getKeys()) : + Seq.empty()).map( + key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", + array(Seq.empty()))) : object(Map.of("name", string(qedTable.getName()), "fields", + array(qedTable.getColumnNames().map(JSONSerializer::string)), "types", + array(qedTable.getColumnTypes().map(JSONSerializer::type)), "nullable", + array(qedTable.getColumnTypes().map(type -> bool(type.isNullable()))), "key", + array(Seq.from(qedTable.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), + "guaranteed", array(qedTable.getConstraints().map(visitor::serialize).toImmutableSeq()))); + })); + + return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + } + + private record Rel(Env env) { + Rel() { + this(new Env(0, ImmutableMap.empty(), MutableList.create())); + } + + public JsonNode serialize(RelNode rel) { + return switch (rel) { + case TableScan scan -> object(Map.of("scan", integer(env.resolve(scan.getTable())))); + case LogicalValues values -> { + var visitor = new Rex(env.rex(0)); + var schema = + array(Seq.from(values.getRowType().getFieldList()).map(field -> type(field.getType()))); + var records = array(Seq.from(values.getTuples()) + .map(tuple -> array(Seq.from(tuple).map(visitor::serialize)))); + yield object(Map.of("values", object(Map.of("schema", schema, 
"content", records)))); + } + case LogicalFilter filter -> { + var input = filter.getInput(); + var visitor = + new Rex(env.recorded(filter.getVariablesSet()).rex(input.getRowType().getFieldCount())); + yield object(Map.of("filter", + object(Map.of("condition", visitor.serialize(filter.getCondition()), "source", + serialize(input))))); + } + case LogicalProject project -> { + var input = project.getInput(); + var visitor = + new Rex(env.recorded(project.getVariablesSet()).rex(input.getRowType().getFieldCount())); + var targets = array(Seq.from(project.getProjects()).map(visitor::serialize)); + yield object(Map.of("project", object(Map.of("target", targets, "source", serialize(input))))); + } + case LogicalJoin join -> { + var left = join.getLeft(); + var right = join.getRight(); + var visitor = new Rex(env.recorded(join.getVariablesSet()) + .rex(left.getRowType().getFieldCount() + right.getRowType().getFieldCount())); + yield object(Map.of("join", + object(Map.of("kind", string(join.getJoinType().toString()), "condition", + visitor.serialize(join.getCondition()), "left", serialize(left), "right", + serialize(right))))); + } + case LogicalCorrelate correlate -> { + var left = correlate.getLeft(); + var rightVisitor = new Rel(env.recorded(correlate.getVariablesSet()) + .lifted(left.getRowType().getFieldCount())); + yield object(Map.of("correlate", + object(Map.of("kind", string(correlate.getJoinType().toString()), "left", serialize(left), + "right", rightVisitor.serialize(correlate.getRight()))))); + } + case LogicalAggregate aggregate -> { + var level = env.lvl(); + var input = aggregate.getInput(); + var inputTypes = Seq.from(input.getRowType().getFieldList()).map(field -> type(field.getType())); + var keys = array(Seq.from(aggregate.getGroupSet()) + .map(col -> object(Map.of("column", integer(level + col), "type", inputTypes.get(col))))); + var aggs = array(Seq.from(aggregate.getAggCallList()).map(call -> object( + Map.of("operator", 
string(call.getAggregation().getName()), "operand", + array(Seq.from(call.getArgList()).map(col -> object( + Map.of("column", integer(level + col), "type", inputTypes.get(col))))), + "distinct", bool(call.isDistinct()), "ignoreNulls", bool(call.ignoreNulls()), + "type", type(call.getType()))))); + yield object(Map.of("group", + object(Map.of("keys", keys, "function", aggs, "source", serialize(input))))); + } + case LogicalUnion union -> { + var result = object(Map.of("union", array(Seq.from(union.getInputs()).map(this::serialize)))); + yield union.all ? result : object(Map.of("distinct", result)); + } + case LogicalIntersect intersect when !intersect.all -> + object(Map.of("intersect", array(Seq.from(intersect.getInputs()).map(this::serialize)))); + case LogicalMinus minus when !minus.all -> + object(Map.of("except", array(Seq.from(minus.getInputs()).map(this::serialize)))); + case LogicalSort sort -> { + var input = sort.getInput(); + var types = Seq.from(input.getRowType().getFieldList()).map(field -> type(field.getType())); + var collations = array(Seq.from(sort.collation.getFieldCollations()).map(collation -> { + var index = collation.getFieldIndex(); + return array(Seq.of(integer(index), types.get(index), string(collation.getDirection().name()))); + })); + var visitor = new Rex(env.rex(0)); + yield object(Map.of("sort", + object(Map.of("collation", collations, "source", serialize(input), "offset", + sort.offset != null ? visitor.serialize(sort.offset) : NullNode.instance, "limit", + sort.fetch != null ? 
visitor.serialize(sort.fetch) : NullNode.instance)))); + } + default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); + }; + } + + private record Env(int lvl, ImmutableMap globals, MutableList tables) { + Env recorded(Set ids) { + return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); + } + + Env lifted(int d) { + return new Env(lvl + d, globals, tables); + } + + int resolve(RelOptTable table) { + var idx = tables.map(JSONSerializer::qualifiedTableName).indexOf(qualifiedTableName(table)); + if (idx == -1) { + idx = tables.size(); + tables.append(table); + } + return idx; + } + + public Rex.Env rex(int delta) { + return new Rex.Env(lvl, delta, globals, tables); + } + } + } + + private record Rex(Env env) { + public JsonNode serialize(RexNode rex) { + return switch (rex) { + case RexInputRef inputRef -> object(Map.of("column", integer(inputRef.getIndex() + env.base()), "type", + type(inputRef.getType()))); + case RexLiteral literal -> object(Map.of("operator", + string(literal.getValue() == null ? 
"NULL" : literal.getValue().toString()), "operand", + array(Seq.empty()), "type", type(literal.getType()))); + case RexSubQuery subQuery -> + object(Map.of("operator", string(subQuery.getOperator().getName()), "operand", + array(Seq.from(subQuery.getOperands()).map(this::serialize)), "query", + new Rel(env.rel()).serialize(subQuery.rel), "type", type(subQuery.getType()))); + case RexCall call -> object(Map.of("operator", string(call.getOperator().getName()), "operand", + array(Seq.from(call.getOperands()).map(this::serialize)), "type", type(call.getType()))); + case RexFieldAccess fieldAccess -> object(Map.of("column", integer(fieldAccess.getField().getIndex() + + env.resolve(((RexCorrelVariable) fieldAccess.getReferenceExpr()).id)), "type", + type(fieldAccess.getType()))); + default -> throw new RuntimeException("Not implemented: " + rex.getKind()); + }; + } + + private record Env(int base, int delta, ImmutableMap globals, + MutableList tables) { + public Rel.Env rel() { + return new Rel.Env(base + delta, globals, tables); + } + + int resolve(CorrelationId id) { + return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + } + } + } +} diff --git a/src/main/java/org/qed/Main.java b/src/main/java/org/qed/Main.java new file mode 100644 index 0000000..5405713 --- /dev/null +++ b/src/main/java/org/qed/Main.java @@ -0,0 +1,150 @@ +package org.qed; + +import org.apache.calcite.tools.RelBuilder; +import org.apache.commons.io.FilenameUtils; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + + +/** + * The Main logic of Qed-Parser. + */ +public class Main { + + public static void main(String[] args) { + for (String filename : args) { + parseFile(filename); + } + } + + /** + * Parse a file or a directory of files. + * + * @param path The input path. 
+ */ + + public static void parseFile(String path) { + String type = FilenameUtils.getExtension(path); + if (type.equals("sql")) { + parseSQLFile(path); + } else if (type.equals("cos")) { + parseCOSFile(path); + } else { + File object = new File(path); + if (object.isDirectory()) { + for (File file : Objects.requireNonNull(object.listFiles())) { + parseFile(file.getPath()); + } + } + } + } + + /** + * Parse a .sql file. + * + * @param filename The input filename. + */ + + private static void parseSQLFile(String filename) { + try { + Pattern comment = Pattern.compile("--.*(\\r?\\n|$)"); + Scanner scanner = new Scanner(new File(filename)); + SchemaGenerator generator = new SchemaGenerator(); + SQLJSONParser parser = new SQLJSONParser(); + scanner.useDelimiter(Pattern.compile(";")); + while (scanner.hasNext()) { + String statement = comment.matcher(scanner.next()).replaceAll("\n").trim(); + if (!statement.isBlank()) { + try { + if (statement.toUpperCase().startsWith("CREATE")) { + generator.applyCreate(statement); + } else if (statement.toUpperCase().startsWith("DECLARE")) { + generator.applyDeclareFunction(statement); + } else { + parser.parseDML(generator.extractSchema(), statement); + } + } catch (Exception e) { + throw new Exception( + "In statement:\n" + statement.replaceAll("(?m)^", "\t") + "\n" + e.getMessage()); + } + } + } + String outputPath = FilenameUtils.getFullPath(filename) + FilenameUtils.getBaseName(filename); + var builder = RelBuilder.create(RawPlanner.generateConfig(generator.extractSchema())); + parser.dumpOutput(builder, outputPath); + scanner.close(); + } catch (Exception e) { + System.err.println("In file:\n\t" + filename); + System.err.println(e.toString().trim() + "\n"); + } + } + + /** + * Assuming that the .cos file is always in the following format:
+ * schema schema_name(column:int, ...);
+ * table table_name(schema_name);
+ * query _ `query_body`;
+ * The .cos file will be translated to a .sql file, which contains the translated SQL statement. + * Then the .sql file will be passed to parseSQLFile(...) and will not be deleted after it is used. + * + * @param filename The input .cos filename + */ + private static void parseCOSFile(String filename) { + try { + Scanner scanner = new Scanner(new File(filename)); + Pattern schemaPattern = Pattern.compile("(?<=schema\\s)(\\w+)\\((.*)\\)$"); + Pattern tablePattern = Pattern.compile("(?<=table\\s)(\\w+)\\((\\w+)\\)$"); + Pattern declarationPattern = Pattern.compile("(\\w+):\\w+,?\\s?"); + Pattern queryPattern = Pattern.compile("(?<=`)[\\s\\S]*(?=`)"); + StringBuilder sqlBuilder = new StringBuilder(); + Map schemas = new HashMap<>(); + scanner.useDelimiter(Pattern.compile(";")); + while (scanner.hasNext()) { + String line = scanner.next(); + Matcher schemaMatcher = schemaPattern.matcher(line); + Matcher tableMatcher = tablePattern.matcher(line); + Matcher queryMatcher = queryPattern.matcher(line); + if (schemaMatcher.find()) { + StringBuilder schema = new StringBuilder(); + Matcher declarationMatcher = declarationPattern.matcher(schemaMatcher.group(2)); + schema.append(" ("); + while (declarationMatcher.find()) { + schema.append("\n\t"); + schema.append(declarationMatcher.group(1).toUpperCase(Locale.ROOT)); + schema.append(" INTEGER,"); + } + if (declarationMatcher.reset().find()) { + schema.deleteCharAt(schema.length() - 1); + } else { + schema.append("\n\tCOL INTEGER"); + } + schema.append("\n);\n"); + schemas.put(schemaMatcher.group(1).toUpperCase(Locale.ROOT), schema.toString()); + } else if (tableMatcher.find()) { + sqlBuilder.append("CREATE TABLE "); + sqlBuilder.append(tableMatcher.group(1).toUpperCase(Locale.ROOT)); + sqlBuilder.append(schemas.get(tableMatcher.group(2).toUpperCase(Locale.ROOT))); + } else if (queryMatcher.find()) { + sqlBuilder.append(queryMatcher.group().toUpperCase(Locale.ROOT)); + sqlBuilder.append(";\n"); + } + } + scanner.close(); + 
String intermediate = FilenameUtils.getFullPath(filename) + FilenameUtils.getBaseName(filename) + ".sql"; + File sql = new File(intermediate); + BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(sql)); + bufferedWriter.write(sqlBuilder.toString()); + bufferedWriter.close(); + parseSQLFile(intermediate); + } catch (Exception e) { + System.err.println("In file:\n\t" + filename); + System.err.println(e.toString().trim() + "\n"); + } + } + +} diff --git a/src/main/java/org/qed/ProjectPaths.java b/src/main/java/org/qed/ProjectPaths.java new file mode 100644 index 0000000..642518f --- /dev/null +++ b/src/main/java/org/qed/ProjectPaths.java @@ -0,0 +1,19 @@ +package org.qed; + +import java.nio.file.Path; + +/** + * Resolves repo-relative paths for codegen and tests. Maven sets {@code -Drulescript.basedir}; + * otherwise {@code user.dir} is used (run from repository root). + */ +public final class ProjectPaths { + private ProjectPaths() {} + + public static Path baseDir() { + String override = System.getProperty("rulescript.basedir"); + if (override != null && !override.isBlank()) { + return Path.of(override); + } + return Path.of(System.getProperty("user.dir")); + } +} diff --git a/src/main/java/org/qed/QedTable.java b/src/main/java/org/qed/QedTable.java new file mode 100644 index 0000000..9617913 --- /dev/null +++ b/src/main/java/org/qed/QedTable.java @@ -0,0 +1,80 @@ +package org.qed; + +import kala.collection.Map; +import kala.collection.Seq; +import kala.collection.Set; +import kala.collection.immutable.ImmutableSet; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.schema.Statistic; +import org.apache.calcite.schema.Statistics; +import org.apache.calcite.schema.impl.AbstractTable; +import org.apache.calcite.util.ImmutableBitSet; + +public class QedTable extends AbstractTable { + private final String name; + + private final Seq 
columnNames; + + private final Seq columnTypes; + private final Set keys; + private final Set constraints; + + public QedTable(String n, Seq cn, Seq ct, Set k, Set cs) { + name = n; + columnNames = cn; + columnTypes = ct; + keys = k; + constraints = cs; + } + + public QedTable(String identifier, Map columns, Set> eligibleKeys, + Set checkConstraints) { + name = identifier; + columnNames = columns.keysView().toImmutableSeq().sorted(); + columnTypes = columnNames.map(columns::get); + keys = Set.from(eligibleKeys.map(key -> ImmutableBitSet.of(key.map(columnNames::indexOf)))); + constraints = checkConstraints; + } + + public QedTable(String identifier, Map columns, ImmutableSet eligibleKeys, + Set checkConstraints) { + name = identifier; + columnNames = columns.keysView().toImmutableSeq().sorted(); + columnTypes = columnNames.map(columns::get); + keys = eligibleKeys; + constraints = checkConstraints; + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + return typeFactory.createStructType(columnNames.zip(columnTypes).view() + .map(entry -> java.util.Map.entry(entry.component1(), entry.component2())).toImmutableSeq().asJava()); + } + + @Override + public Statistic getStatistic() { + return Statistics.of(0, keys.toImmutableSeq().asJava()); + } + + public String getName() { + return name; + } + + public Seq getColumnNames() { + return columnNames; + } + + public Seq getColumnTypes() { + return columnTypes; + } + + public Set getKeys() { + return keys; + } + + public Set getConstraints() { + return constraints; + } +} diff --git a/src/main/java/org/qed/RRule.java b/src/main/java/org/qed/RRule.java new file mode 100644 index 0000000..a548f9a --- /dev/null +++ b/src/main/java/org/qed/RRule.java @@ -0,0 +1,115 @@ +package org.qed; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import kala.collection.Map; +import kala.collection.Seq; + +import java.io.File; +import 
java.io.IOException; + +public interface RRule { + RelRN before(); + + RelRN after(); + + default String explain() { + return getClass().getName() + + "\n" + + before().semantics().explain() + + "=>" + + "\n" + + after().semantics().explain(); + } + + default String name() { + return getClass().getSimpleName(); + } + + default String info() { + return ""; + } + + default ObjectNode toJson() { + return JSONSerializer.serialize(Seq.of(before().semantics(), after().semantics())); + } + + default void dump(String path) throws IOException { + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(new File(path), toJson()); + } + + interface RRuleFamily { + Seq family(); + } + + record RRuleGenerator(RRule rule, + Seq assignments) implements RRuleFamily { + @Override + public Seq family() { + return assignments.map(assignment -> new RRule() { + + @Override + public RelRN before() { + return assignment.replaceMetaRelRN(rule.before()); + } + + @Override + public RelRN after() { + return assignment.replaceMetaRelRN(rule.after()); + } + + @Override + public String name() { + return rule.name(); + } + + @Override + public String info() { + return assignment.info(); + } + }); + } + + public record MetaAssignment( + Map joinTypeAssignment) { + public RelRN.Join.JoinType replaceMetaJoinType(RelRN.Join.JoinType joinType) { + return switch (joinType) { + case RelRN.Join.JoinType.MetaJoinType metaJoinType -> joinTypeAssignment.get(metaJoinType); + default -> joinType; + }; + } + + public RexRN replaceMetaRexRN(RexRN rexRN) { + return switch (rexRN) { + case RexRN.Field field -> replaceMetaRelRN(field.source()).field(field.ordinal()); + default -> customReplaceMetaRexRN(rexRN); + }; + } + + public RelRN replaceMetaRelRN(RelRN relRN) { + return switch (relRN) { + case RelRN.Filter filter -> + replaceMetaRelRN(filter.source()).filter(replaceMetaRexRN(filter.cond())); + case RelRN.Join join -> + replaceMetaRelRN(join.left()).join(replaceMetaJoinType(join.ty()), + 
replaceMetaRexRN(join.cond()), replaceMetaRelRN(join.right())); + default -> customReplaceMetaRelRN(relRN); + }; + } + + public RexRN customReplaceMetaRexRN(RexRN rexRN) { + return rexRN; + } + + public RelRN customReplaceMetaRelRN(RelRN relRN) { + return relRN; + } + + public String info() { + return joinTypeAssignment.joinToString("&", (m, c) -> "{" + m.name() + "}=" + c.semantics()); + } + + } + } +} + diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java new file mode 100644 index 0000000..f7e3711 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstance.java @@ -0,0 +1,481 @@ +package org.qed; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RRuleInstance.JoinAssociate; +// import org.qed.RRuleInstance.JoinConditionPush.JoinPred; + +public interface RRuleInstance { + record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } + } + + record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } + } + + record FilterProjectTranspose() implements RRule { + static final RelRN source = 
RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } + } + + record FilterReduceFalse() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.falseLiteral()); + } + + @Override + public RelRN after() { + return source.empty(); + } + } + + record FilterReduceTrue() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.trueLiteral()); + } + + @Override + public RelRN after() { + return source; + } + } + + // TBD: include intersect to make it a rule familiy + record FilterSetOpTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Common_Type"); + static final RelRN right = RelRN.scan("Right", "Common_Type"); + + @Override + public RelRN before() { + RelRN projTmp = left.union(false, right); + return projTmp.filter(projTmp.pred("filter")); + } + + @Override + public RelRN after() { + RexRN leftPred = left.pred("filter"); + RexRN rightPred = right.pred("filter"); + return left.filter(leftPred).union(false, right.filter(rightPred)); + } + } + + record IntersectMerge() implements RRule { + // Use a common type for all relations to make them compatible + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + // Nested INTERSECT: (A INTERSECT B) INTERSECT C + return a.intersect(false, b).intersect(false, c); + } + + @Override + public RelRN after() { + // Flattened INTERSECT: A INTERSECT B INTERSECT C + return a.intersect(false, b, c); + } + } + + // record 
JoinConditionPush() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Left_Type"); + // static final RelRN right = RelRN.scan("Right", "Right_Type"); + // static final JoinPred joinPred = new JoinPred(left, right); + + // @Override + // public RelRN before() { + // return left.join(JoinRelType.INNER, joinPred, right); + // } + + // @Override + // public RelRN after() { + // var leftRN = left.filter(joinPred.leftPred()); + // var rightRN = right.filter(joinPred.rightPred()); + // return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); + // } + + // public record JoinPred(RelRN left, RelRN right) implements RexRN { + + // @Override + // public RexNode semantics() { + // return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), + // left.joinField(1, right).pred(rightPred())).semantics(); + // } + + // public String bothPred() {return "both";} + + // public String leftPred() {return "left";} + + // public String rightPred() {return "right";} + + // } + // } + + record JoinAddRedundantSemiJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, pred, right).join(JoinRelType.INNER, pred, right); + } + } + + // Todo: explore join types, see line 102 of JoinAssociateRule + record JoinAssociate() implements RRule.RRuleFamily { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); + static final String pred_ab = "pred_ab"; + static final String pred_bc = "pred_bc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final 
RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } + + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); + } + + @Override + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); + } + } + + record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + 
static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("pred", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + // We need to swap the join fields in the condition + RexRN commutedJoinCond = right.joinPred("pred", left); + return right.join(JoinRelType.INNER, commutedJoinCond, left); + } + } + + record JoinExtractFilter() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.trueLiteral(), right).filter(joinCond); + } + } + +// record JoinProjectTranspose() implements RRule { +// +// } + + // JoinConditionPush? +// record JoinPushExpressions() implements RRule { +// +// } + + // JoinConditionPush? 
+// record JoinPushTransitivePredicates() implements RRule { +// +// } + +// record JoinToSemiJoin() implements RRule { +// +// } + +// record JoinLeftUnionTranspose() implements RRule { +// +// } + +// record JoinRightUnionTranspose() implements RRule { +// +// } + +// record ProjectJoinRemove() implements RRule { +// +// @Override +// public RelRN before() { +// return null; +// } +// +// @Override +// public RelRN after() { +// return null; +// } +// } + +// record ProjectJoinJoinRemove() implements RRule { +// +// } + + record ProjectJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN proj = left.proj("proj", "Project_Type"); + static final String joinCond = left.joinPred("join", right).toString(); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.INNER, joinCond, right); + } + } + + record ProjectMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.proj("inner", "Inner_Type"); + static final String outer = "outer"; + static final String outerType = "Outer_Type"; + + @Override + public RelRN before() { + return source.project(inner).project(outer, outerType); + } + + @Override + public RelRN after() { + return source.project(inner.proj(outer, outerType)); + } + } + + //TBD: currently provable for UNION ALL while unprovable for UNION + // record ProjectSetOpTranspose() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Common_Type"); + // static final RelRN right = RelRN.scan("Right", "Common_Type"); + + // @Override + // public RelRN before() { + // RelRN projTmp = left.union(true, right); + // return projTmp.project(projTmp.proj("proj", "Proj_Type")); + // } + + // @Override + // public RelRN 
after() { + // RelRN projA = left.project(left.proj("proj", "Proj_Type")); + // RelRN projB = right.project(right.proj("proj", "Proj_Type")); + // return projA.union(true, projB); + // } + // } + + + /* TBD: Already optimized by calcite? */ + // record ProjectRemove() implements RRule { + // static final RelRN source = RelRN.scan("Source", "Source_Type"); + + // @Override + // public RelRN before() { + // return source.project(source.field(0)); + // } + + // @Override + // public RelRN after() { + // return source; + // } + // } + + record UnionMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.union(false, b).union(false, c); + } + + @Override + public RelRN after() { + return null; + } + } + + record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "Left_Type"); + static final RelRN right = RelRN.scan("right", "Right_Type"); + static final RexRN pred = left.pred("pred"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, joinCond, right).filter(pred); + } + + @Override + public RelRN after() { + return left.filter(pred).join(JoinRelType.SEMI, joinCond, right); + } + } + + record SemiJoinJoinTranspose() implements RRule { + static final RelRN r = RelRN.scan("R", "R_Type"); + static final RelRN s = RelRN.scan("S", "S_Type"); + static final RelRN t = RelRN.scan("T", "T_Type"); + static final RexRN semiCond = r.joinPred("semi", s); + static final RexRN joinCond = r.joinPred("join", t); + + @Override + public RelRN before() { + return r.join(JoinRelType.INNER, joinCond, t).join(JoinRelType.SEMI, semiCond, s); + } + + @Override + public RelRN after() { + return r.join(JoinRelType.SEMI, semiCond, s).join(JoinRelType.INNER, joinCond, t); + } + } + 
+ record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN proj = left.proj("proj", "Project_Type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } + } + + record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } + } + +// record UnionMerge() implements RRule { +// +// } + +// record UnionRemove() implements RRule { +// +// } +} + +/* + * Semantically identical cases: + * FilterExpandIsNotDistinctFrom + * FilterScan + * JoinReduceExpression + * ProjectReduceExpression + * ProjectTableScan + */ diff --git a/src/main/java/org/qed/RRuleInstances/AggregateExtractProject.java b/src/main/java/org/qed/RRuleInstances/AggregateExtractProject.java new file mode 100644 index 0000000..97037fd --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/AggregateExtractProject.java @@ -0,0 +1,23 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RRule; +import org.qed.RexRN; + +public record AggregateExtractProject() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.aggregate( + proj.groupBy("groupByName"), + proj.aggCall("aggName") + ); + } + + @Override + public RelRN after() { + return source.project(proj).aggregate("groupByName", "aggName"); + } +} \ 
No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/AggregateFilterTranspose.java b/src/main/java/org/qed/RRuleInstances/AggregateFilterTranspose.java new file mode 100644 index 0000000..6eaba5d --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/AggregateFilterTranspose.java @@ -0,0 +1,22 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record AggregateFilterTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN.GroupBy groupExpr = source.groupBy("groupByName"); + + @Override + public RelRN before() { + return source.filter(groupExpr.pred("pred")) + .aggregate(groupExpr, source.aggCall("aggName")); + } + + @Override + public RelRN after() { + RelRN aggregated = source.aggregate(groupExpr, source.aggCall("aggName")); + return aggregated.filter(aggregated.field(0).pred("pred")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/AggregateJoinJoinRemove.java b/src/main/java/org/qed/RRuleInstances/AggregateJoinJoinRemove.java new file mode 100644 index 0000000..b935c3e --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/AggregateJoinJoinRemove.java @@ -0,0 +1,40 @@ +package org.qed.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RuleBuilder; + +public record AggregateJoinJoinRemove() implements RRule { + static final RelRN tblA = RelRN.scan("sourceA1", "typeA1") + .join(JoinRelType.INNER, RexRN.trueLiteral(), RelRN.scan("sourceA2", "typeA2")); + static final RelRN tblB = RelRN.scan("sourceB1", "typeB1") + .join(JoinRelType.INNER, RexRN.trueLiteral(), RelRN.scan("sourceB2", "typeB2")); + static final RelRN tblC = RelRN.scan("sourceC1", "typeC1") + .join(JoinRelType.INNER, RexRN.trueLiteral(), RelRN.scan("sourceC2", "typeC2")); + + static 
final RexRN bottomJoinCondition = new RexRN.Pred( + RuleBuilder.create().genericPredicateOp("=", true), + Seq.of(tblA.field(0), tblB.field(0)) + ); + + static final RexRN topJoinCondition = new RexRN.Pred( + RuleBuilder.create().genericPredicateOp("=", true), + Seq.of(tblA.field(0), tblC.field(0)) + ); + + @Override + public RelRN before() { + RelRN bottomJoin = tblA.join(JoinRelType.LEFT, bottomJoinCondition, tblB); + RelRN topJoin = bottomJoin.join(JoinRelType.LEFT, topJoinCondition, tblC); + return new RelRN.Aggregate(topJoin, Seq.of(topJoin.field(0), topJoin.field(4)), Seq.empty()); + } + + @Override + public RelRN after() { + RelRN newJoin = tblA.join(JoinRelType.LEFT, topJoinCondition, tblC); + return new RelRN.Aggregate(newJoin, Seq.of(newJoin.field(0), newJoin.field(2)), Seq.empty()); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/AggregateJoinRemove.java b/src/main/java/org/qed/RRuleInstances/AggregateJoinRemove.java new file mode 100644 index 0000000..28688e6 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/AggregateJoinRemove.java @@ -0,0 +1,32 @@ +package org.qed.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RuleBuilder; + +public record AggregateJoinRemove() implements RRule { + static final RelRN tblA = RelRN.scan("sourceA1", "typeA1") + .join(JoinRelType.INNER, RexRN.trueLiteral(), RelRN.scan("sourceA2", "typeA2")); + + static final RelRN tblB = RelRN.scan("sourceB1", "typeB1") + .join(JoinRelType.INNER, RexRN.trueLiteral(), RelRN.scan("sourceB2", "typeB2")); + + static final RexRN joinCondition = new RexRN.Pred( + RuleBuilder.create().genericPredicateOp("=", true), + Seq.of(tblA.field(0), tblB.field(0)) + ); + + @Override + public RelRN before() { + RelRN leftJoin = tblA.join(JoinRelType.LEFT, joinCondition, tblB); + return new RelRN.Aggregate(leftJoin, 
Seq.of(leftJoin.field(0)), Seq.empty()); + } + + @Override + public RelRN after() { + return new RelRN.Aggregate(tblA, Seq.of(tblA.field(0)), Seq.empty()); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/AggregateProjectConstantToDummyJoin.java b/src/main/java/org/qed/RRuleInstances/AggregateProjectConstantToDummyJoin.java new file mode 100644 index 0000000..2259aee --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/AggregateProjectConstantToDummyJoin.java @@ -0,0 +1,144 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RRule; +import org.qed.RuleBuilder; +import org.qed.RelType; +import kala.collection.Seq; +import kala.tuple.Tuple; + +public record AggregateProjectConstantToDummyJoin() implements RRule { + + static final RelRN source = new SourceTable(); + + @Override + public RelRN before() { + var projectWithConstants = new ProjectWithConstantLiterals(source); + return new AggregateGroupingByConstants(projectWithConstants); + } + + @Override + public RelRN after() { + var dummyTable = new DummyConstantsTable(); + var joinWithDummy = new JoinWithDummyTable(source, dummyTable); + var projectWithDummyFields = new ProjectWithDummyFields(joinWithDummy); + return new AggregateGroupingByDummyFields(projectWithDummyFields); + } + + public static record SourceTable() implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("Source_Type", true), false), + Tuple.of(RelType.fromString("Source_Type", true), false) + )); + + builder.addTable(table); + return builder.scan(table.getName()).build(); + } + } + + public static record ProjectWithConstantLiterals(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); 
+ + builder.project( + builder.field(0), + builder.literal(true), + builder.literal("2024"), + builder.field(1) + ); + + return builder.build(); + } + } + + public static record AggregateGroupingByConstants(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + var groupKey = builder.groupKey( + builder.field(1), + builder.field(2), + builder.field(0) + ); + + var agg = builder.avg(builder.field(3)); + + builder.aggregate(groupKey, agg); + return builder.build(); + } + } + + public static record DummyConstantsTable() implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + builder.values( + new String[]{"col0", "col1"}, + true, + "2024" + ); + + return builder.build(); + } + } + + public static record JoinWithDummyTable(RelRN baseTable, RelRN dummyTable) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + builder.push(baseTable.semantics()); + builder.push(dummyTable.semantics()); + + builder.join(JoinRelType.INNER, builder.literal(true)); + + return builder.build(); + } + } + + public static record ProjectWithDummyFields(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + builder.project( + builder.field(0), + builder.field(2), + builder.field(3), + builder.field(1) + ); + + return builder.build(); + } + } + public static record AggregateGroupingByDummyFields(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + var groupKey = builder.groupKey( + builder.field(1), + builder.field(2), + builder.field(0) + ); + + var agg = builder.avg(builder.field(3)); + + builder.aggregate(groupKey, agg); + return builder.build(); + } + } +} \ No newline at end of file diff --git 
a/src/main/java/org/qed/RRuleInstances/AggregateProjectMerge.java b/src/main/java/org/qed/RRuleInstances/AggregateProjectMerge.java new file mode 100644 index 0000000..cf3da67 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/AggregateProjectMerge.java @@ -0,0 +1,23 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RRule; +import org.qed.RexRN; + +public record AggregateProjectMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.project(proj).aggregate("groupByName", "aggName"); + } + + @Override + public RelRN after() { + return source.aggregate( + proj.groupBy("groupByName"), + proj.aggCall("aggName") + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/RRuleInstances/FilterAggregateTranspose.java new file mode 100644 index 0000000..aa3ada0 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterAggregateTranspose.java @@ -0,0 +1,22 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterAggregateTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN.GroupBy groupExpr = source.groupBy("groupByName"); + + @Override + public RelRN before() { + RelRN aggregated = source.aggregate(groupExpr, source.aggCall("aggName")); + return aggregated.filter(aggregated.field(0).pred("pred")); + } + + @Override + public RelRN after() { + return source.filter(groupExpr.pred("pred")) + .aggregate(groupExpr, source.aggCall("aggName")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java new file mode 100644 index 0000000..8dcbd70 --- /dev/null +++ 
b/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java @@ -0,0 +1,23 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterMerge.java b/src/main/java/org/qed/RRuleInstances/FilterMerge.java new file mode 100644 index 0000000..473d3c1 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterMerge.java @@ -0,0 +1,21 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java b/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java new file mode 100644 index 0000000..e7154b4 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterProjectTranspose() implements RRule { + static final RelRN source = 
RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.project(proj).filter("pred"); + } + + @Override + public RelRN after() { + return source.filter(proj.pred("pred")).project(proj); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java b/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java new file mode 100644 index 0000000..ccaa73c --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java @@ -0,0 +1,19 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterReduceFalse() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.falseLiteral()); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java b/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java new file mode 100644 index 0000000..d67d14e --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java @@ -0,0 +1,19 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterReduceTrue() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.trueLiteral()); + } + + @Override + public RelRN after() { + return source; + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java b/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java new file mode 100644 index 0000000..4d9ea3d --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java @@ -0,0 +1,23 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; 
+ +public record FilterSetOpTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Common_Type"); + static final RelRN right = RelRN.scan("Right", "Common_Type"); + + @Override + public RelRN before() { + RelRN projTmp = left.union(false, right); + return projTmp.filter(projTmp.pred("filter")); + } + + @Override + public RelRN after() { + RexRN leftPred = left.pred("filter"); + RexRN rightPred = right.pred("filter"); + return left.filter(leftPred).union(false, right.filter(rightPred)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/IntersectMerge.java b/src/main/java/org/qed/RRuleInstances/IntersectMerge.java new file mode 100644 index 0000000..a4ea680 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/IntersectMerge.java @@ -0,0 +1,29 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record IntersectMerge() implements RRule { + // Use a common type for all relations to make them compatible + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + // Nested INTERSECT: (A INTERSECT B) INTERSECT C + return a.intersect(false, b).intersect(false, c); + } + + @Override + public RelRN after() { + // Flattened INTERSECT: A INTERSECT B INTERSECT C + return a.intersect(false, b, c); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java new file mode 100644 index 0000000..86c6ca7 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java @@ -0,0 +1,21 @@ +package org.qed.RRuleInstances; + +import 
org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RRule; + +public record JoinAddRedundantSemiJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, pred, right).join(JoinRelType.INNER, pred, right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/RRuleInstances/JoinCommute.java new file mode 100644 index 0000000..ec54ca5 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinCommute.java @@ -0,0 +1,48 @@ +package org.qed.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + SqlOperator predOp = RuleBuilder.create().genericPredicateOp(pred, true); + RexRN swappedPred = new RexRN.Pred(predOp, Seq.of( + new RexRN.JoinField(1, right, left), + new RexRN.JoinField(0, right, left) + )); + RelRN swappedJoin = right.join(JoinRelType.INNER, swappedPred, left); + + return new ProjectionRelRN(swappedJoin); + } + public static record ProjectionRelRN(RelRN source) implements RelRN { + @Override + public RelNode semantics() { + RuleBuilder builder = RuleBuilder.create(); + 
builder.push(source.semantics()); + + RexNode leftField = builder.field(1); + RexNode rightField = builder.field(0); + + builder.project(leftField, rightField); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/JoinConditionPush.java b/src/main/java/org/qed/RRuleInstances/JoinConditionPush.java new file mode 100644 index 0000000..468aa86 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinConditionPush.java @@ -0,0 +1,56 @@ +package org.qed.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.SqlOperator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinConditionPush() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + SqlOperator joinOp = RuleBuilder.create().genericPredicateOp("joinCond", true); + RexRN crossTableCond = new RexRN.Pred(joinOp, Seq.of( + new RexRN.JoinField(0, left, right), + new RexRN.JoinField(1, left, right) + )); + SqlOperator leftOp = RuleBuilder.create().genericPredicateOp("leftCond", true); + RexRN leftOnlyCond = new RexRN.Pred(leftOp, Seq.of( + new RexRN.JoinField(0, left, right) + )); + SqlOperator rightOp = RuleBuilder.create().genericPredicateOp("rightCond", true); + RexRN rightOnlyCond = new RexRN.Pred(rightOp, Seq.of( + new RexRN.JoinField(1, left, right) + )); + return left.joinWithSeparateConds(JoinRelType.INNER, + RexRN.and(crossTableCond, leftOnlyCond, rightOnlyCond), right); + } + + @Override + public RelRN after() { + SqlOperator joinOp = RuleBuilder.create().genericPredicateOp("joinCond", true); + RexRN crossTableCond = new RexRN.Pred(joinOp, Seq.of( + new RexRN.JoinField(0, left, right), + new RexRN.JoinField(1, left, right) + )); + + SqlOperator leftOp = 
RuleBuilder.create().genericPredicateOp("leftCond", true); + RexRN leftFilterCond = new RexRN.Pred(leftOp, Seq.of( + new RexRN.Field(0, left) + )); + RelRN filteredLeft = left.filter(leftFilterCond); + + SqlOperator rightOp = RuleBuilder.create().genericPredicateOp("rightCond", true); + RexRN rightFilterCond = new RexRN.Pred(rightOp, Seq.of( + new RexRN.Field(0, right) + )); + RelRN filteredRight = right.filter(rightFilterCond); + + return filteredLeft.joinWithPushedConds(JoinRelType.INNER, crossTableCond, filteredRight); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java b/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java new file mode 100644 index 0000000..fc78f17 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java @@ -0,0 +1,22 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record JoinExtractFilter() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.trueLiteral(), right).filter(joinCond); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinPushTransitivePredicates.java b/src/main/java/org/qed/RRuleInstances/JoinPushTransitivePredicates.java new file mode 100644 index 0000000..aa78ede --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinPushTransitivePredicates.java @@ -0,0 +1,23 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record JoinPushTransitivePredicates() implements RRule { + 
static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN cond1 = left.joinPred("cond1", right); + static final RexRN cond2 = left.joinPred("cond2", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, cond1, right).filter(cond2); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(cond1, cond2), right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinReduceFalse.java b/src/main/java/org/qed/RRuleInstances/JoinReduceFalse.java new file mode 100644 index 0000000..a537196 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinReduceFalse.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record JoinReduceFalse() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = RexRN.and(left.joinPred("join", right), RexRN.falseLiteral()); + + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.falseLiteral(), right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinReduceTrue.java b/src/main/java/org/qed/RRuleInstances/JoinReduceTrue.java new file mode 100644 index 0000000..61b2106 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinReduceTrue.java @@ -0,0 +1,21 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record JoinReduceTrue() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN afterJoinCond = 
left.joinPred("join", right); + static final RexRN beforeJoinCond = RexRN.and(afterJoinCond, RexRN.trueLiteral()); + + public RelRN before() { + return left.join(JoinRelType.INNER, beforeJoinCond, right); + } + + public RelRN after() { + return left.join(JoinRelType.INNER, afterJoinCond, right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/MinusMerge.java b/src/main/java/org/qed/RRuleInstances/MinusMerge.java new file mode 100644 index 0000000..9367c63 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/MinusMerge.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RRule; + +public record MinusMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.minus(false, b).minus(false, c); + } + + @Override + public RelRN after() { + return a.minus(false, b.union(false, c)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/ProjectAggregateMerge.java b/src/main/java/org/qed/RRuleInstances/ProjectAggregateMerge.java new file mode 100644 index 0000000..1e2ad8e --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/ProjectAggregateMerge.java @@ -0,0 +1,121 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.RelNode; +import org.qed.RelRN; +import org.qed.RRule; +import org.qed.RuleBuilder; +import org.qed.RelType; +import kala.collection.Seq; +import kala.tuple.Tuple; + +public record ProjectAggregateMerge() implements RRule { + static final RelRN source = new SourceTable(); + + @Override + public RelRN before() { + var aggregateWithUnusedCalls = new AggregateWithMultipleCalls(source); + return new ProjectUsingSubsetOfAggregates(aggregateWithUnusedCalls); + } + + @Override + public RelRN after() { + var aggregateOptimized = new AggregateWithUsedCallsOnly(source); + return new 
ProjectOptimized(aggregateOptimized); + } + + public static record SourceTable() implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("Source_Type", true), false), + Tuple.of(RelType.fromString("Source_Type", true), false), + Tuple.of(RelType.fromString("Source_Type", true), false), + Tuple.of(RelType.fromString("Source_Type", true), false) + )); + + builder.addTable(table); + return builder.scan(table.getName()).build(); + } + } + + public static record AggregateWithMultipleCalls(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + var groupKey = builder.groupKey(builder.field(0)); + + var agg1 = builder.sum(false, "agg1", builder.field(1)); // Will be used + var agg2 = builder.avg(builder.field(2)); // Will be unused + var agg3 = builder.count(false, "agg3", builder.field(3)); // Will be used + var agg4 = builder.max(builder.field(1)); // Will be unused + + builder.aggregate(groupKey, agg1, agg2, agg3, agg4); + return builder.build(); + } + } + + public static record ProjectUsingSubsetOfAggregates(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + builder.project( + builder.field(0), + builder.field(1), + builder.field(3) + ); + + return builder.build(); + } + } + + /** + * Optimized aggregate with only used calls + * agg2 and agg4 are eliminated since they're not used + */ + public static record AggregateWithUsedCallsOnly(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + var groupKey = builder.groupKey(builder.field(0)); + + // Only the aggregate calls that are actually used + var agg1 = builder.sum(false, "agg1", builder.field(1)); // Used + 
var agg3 = builder.count(false, "agg3", builder.field(3)); // Used + // agg2 and agg4 removed - they were unused + + builder.aggregate(groupKey, agg1, agg3); + return builder.build(); + } + } + + /** + * Optimized project with adjusted field references + */ + public static record ProjectOptimized(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + // After optimization, field layout is: + // field(0) = group key + // field(1) = agg1 (was field 1, still field 1) + // field(2) = agg3 (was field 3, now field 2) + builder.project( + builder.field(0), + builder.field(1), + builder.field(2) + ); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/ProjectFilterTranspose.java b/src/main/java/org/qed/RRuleInstances/ProjectFilterTranspose.java new file mode 100644 index 0000000..2a652e6 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/ProjectFilterTranspose.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record ProjectFilterTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/ProjectMerge.java b/src/main/java/org/qed/RRuleInstances/ProjectMerge.java new file mode 100644 index 0000000..6158b60 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/ProjectMerge.java @@ -0,0 +1,22 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record ProjectMerge() implements RRule { + static final RelRN source = 
RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.proj("inner", "Inner_Type"); + static final String outer = "outer"; + static final String outerType = "Outer_Type"; + + @Override + public RelRN before() { + return source.project(inner).project(outer, outerType); + } + + @Override + public RelRN after() { + return source.project(inner.proj(outer, outerType)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/PruneEmptyFilter.java b/src/main/java/org/qed/RRuleInstances/PruneEmptyFilter.java new file mode 100644 index 0000000..12931c2 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/PruneEmptyFilter.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyFilter() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN cond = source.pred("filter_cond"); + + @Override + public RelRN before() { + return source.empty().filter(cond); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/PruneEmptyIntersect.java b/src/main/java/org/qed/RRuleInstances/PruneEmptyIntersect.java new file mode 100644 index 0000000..8d9d284 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/PruneEmptyIntersect.java @@ -0,0 +1,19 @@ +package org.qed.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; + +public record PruneEmptyIntersect() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.intersect(false, b.empty()); + } + + @Override + public RelRN after() { + return a.empty().intersect(false, b.empty()); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/PruneEmptyMinus.java b/src/main/java/org/qed/RRuleInstances/PruneEmptyMinus.java new file mode 100644 index 0000000..be9dab9 --- 
/dev/null +++ b/src/main/java/org/qed/RRuleInstances/PruneEmptyMinus.java @@ -0,0 +1,19 @@ +package org.qed.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; + +public record PruneEmptyMinus() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.empty().minus(false, b); + } + + @Override + public RelRN after() { + return a.empty(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/PruneEmptyProject.java b/src/main/java/org/qed/RRuleInstances/PruneEmptyProject.java new file mode 100644 index 0000000..e61b68f --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/PruneEmptyProject.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyProject() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.empty().project(proj); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/PruneEmptyUnion.java b/src/main/java/org/qed/RRuleInstances/PruneEmptyUnion.java new file mode 100644 index 0000000..515a3f7 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/PruneEmptyUnion.java @@ -0,0 +1,19 @@ +package org.qed.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; + +public record PruneEmptyUnion() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.empty().union(false, b.empty()); + } + + @Override + public RelRN after() { + return a.empty(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java 
b/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java new file mode 100644 index 0000000..fbdc617 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java @@ -0,0 +1,24 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + static final RexRN filterPred = left.pred("filter"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, joinCond, right).filter(filterPred); + } + + @Override + public RelRN after() { + RelRN leftFiltered = left.filter(filterPred); + return leftFiltered.join(JoinRelType.SEMI, leftFiltered.joinPred("join", right), right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/UnionMerge.java b/src/main/java/org/qed/RRuleInstances/UnionMerge.java new file mode 100644 index 0000000..33a7f61 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/UnionMerge.java @@ -0,0 +1,20 @@ +package org.qed.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RRule; + +public record UnionMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.union(false, b).union(false, c); + } + + @Override + public RelRN after() { + return a.union(false, b, c); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/UnionPullUpConstants.java b/src/main/java/org/qed/RRuleInstances/UnionPullUpConstants.java new file mode 100644 index 0000000..2271f12 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/UnionPullUpConstants.java @@ -0,0 
+1,150 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.RelNode; +import org.qed.RelRN; +import org.qed.RRule; +import org.qed.RuleBuilder; +import org.qed.RelType; +import kala.collection.Seq; +import kala.tuple.Tuple; + +public record UnionPullUpConstants() implements RRule { + + static final RelRN left = new SourceTable(); + static final RelRN right = new SourceTable(); + + @Override + public RelRN before() { + var leftProjection = new LeftProjectionWithConstants(left); + var rightProjection = new RightProjectionWithConstants(right); + return new UnionWithConstantColumns(leftProjection, rightProjection); + } + + @Override + public RelRN after() { + var leftProjectionReduced = new LeftProjectionNonConstants(left); + var rightProjectionReduced = new RightProjectionNonConstants(right); + var reducedUnion = new UnionReducedColumns(leftProjectionReduced, rightProjectionReduced); + return new TopProjectionWithConstants(reducedUnion); + } + + public static record SourceTable() implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("Source_Type", true), false), + Tuple.of(RelType.fromString("Source_Type", true), false), + Tuple.of(RelType.fromString("Source_Type", true), false) + )); + + builder.addTable(table); + return builder.scan(table.getName()).build(); + } + } + public static record LeftProjectionWithConstants(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + builder.project( + builder.field(0), + builder.alias(builder.literal("ACTIVE"), "status"), + builder.field(2) + ); + + return builder.build(); + } + } + + public static record RightProjectionWithConstants(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + 
builder.project( + builder.field(0), + builder.alias(builder.literal("ACTIVE"), "status"), + builder.field(2) + ); + + return builder.build(); + } + } + + public static record UnionWithConstantColumns(RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + builder.push(left.semantics()); + builder.push(right.semantics()); + + builder.union(true, 2); // UNION ALL + + return builder.build(); + } + } + + public static record LeftProjectionNonConstants(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + builder.project( + builder.field(0), + builder.field(2) + ); + + return builder.build(); + } + } + + public static record RightProjectionNonConstants(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + builder.project( + builder.field(0), + builder.field(2) + ); + + return builder.build(); + } + } + + public static record UnionReducedColumns(RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + builder.push(left.semantics()); + builder.push(right.semantics()); + + builder.union(true, 2); + + return builder.build(); + } + } + + public static record TopProjectionWithConstants(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + + builder.project( + builder.field(0), + builder.alias(builder.literal("ACTIVE"), "status"), + builder.field(1) + ); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstances/UnionToDistinct.java b/src/main/java/org/qed/RRuleInstances/UnionToDistinct.java new file mode 100644 index 0000000..88bb058 --- /dev/null +++ 
b/src/main/java/org/qed/RRuleInstances/UnionToDistinct.java @@ -0,0 +1,62 @@ +package org.qed.RRuleInstances; + +import org.apache.calcite.rel.RelNode; +import org.qed.RelRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record UnionToDistinct() implements RRule { + + static final RelRN left = RelRN.scan("Left", "Source_Type"); + static final RelRN right = RelRN.scan("Right", "Source_Type"); + + @Override + public RelRN before() { + return new DistinctUnion(left, right); + } + + @Override + public RelRN after() { + var unionAll = new UnionAll(left, right); + return new DistinctAggregate(unionAll); + } + + public static record DistinctUnion(RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(left.semantics()); + builder.push(right.semantics()); + builder.union(false, 2); + + return builder.build(); + } + } + + public static record UnionAll(RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + + builder.push(left.semantics()); + builder.push(right.semantics()); + + builder.union(true, 2); + + return builder.build(); + } + } + + public static record DistinctAggregate(RelRN input) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(input.semantics()); + // Group by all fields to remove duplicates + var groupKey = builder.groupKey(builder.field(0)); + builder.aggregate(groupKey); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RawPlanner.java b/src/main/java/org/qed/RawPlanner.java new file mode 100644 index 0000000..e56d9a5 --- /dev/null +++ b/src/main/java/org/qed/RawPlanner.java @@ -0,0 +1,240 @@ +package org.qed; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.util.Quoting; +import 
org.apache.calcite.config.*; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.*; +import org.apache.calcite.plan.volcano.VolcanoPlanner; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.RelCollationTraitDef; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexExecutor; +import org.apache.calcite.runtime.Hook; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlInsert; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.util.SqlOperatorTables; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorImpl; +import org.apache.calcite.sql2rel.RelDecorrelator; +import org.apache.calcite.sql2rel.SqlRexConvertletTable; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.tools.*; +import org.apache.calcite.util.SourceStringReader; +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.io.Reader; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * A copy of the PlannerImpl that disables all rewrite rules. + */ +public class RawPlanner implements RelOptTable.ViewExpander { + private final SqlOperatorTable operatorTable; + private final ImmutableList programs; + private final @Nullable RelOptCostFactory costFactory; + private final Context context; + private final CalciteConnectionConfig connectionConfig; + + /** + * Holds the trait definitions to be registered with planner. May be null. 
+ */ + private final @Nullable ImmutableList traitDefs; + + private final SqlParser.Config parserConfig; + private final SqlValidator.Config sqlValidatorConfig; + private final SqlToRelConverter.Config sqlToRelConverterConfig; + private final SqlRexConvertletTable convertletTable; + // set in STATE_2_READY + private @Nullable + final SchemaPlus defaultSchema; + private @Nullable + final RexExecutor executor; + // set in STATE_1_RESET + @SuppressWarnings("unused") + private boolean open; + private @Nullable JavaTypeFactory typeFactory; + private @Nullable RelOptPlanner planner; + // set in STATE_4_VALIDATE + private @Nullable SqlValidator validator; + + public RawPlanner(SchemaPlus schema) { + var config = generateConfig(schema); + this.costFactory = config.getCostFactory(); + this.defaultSchema = config.getDefaultSchema(); + this.operatorTable = config.getOperatorTable(); + this.programs = config.getPrograms(); + this.parserConfig = config.getParserConfig(); + this.sqlValidatorConfig = config.getSqlValidatorConfig(); + this.sqlToRelConverterConfig = config.getSqlToRelConverterConfig(); + this.traitDefs = config.getTraitDefs(); + this.convertletTable = config.getConvertletTable(); + this.executor = config.getExecutor(); + this.context = config.getContext(); + this.connectionConfig = connConfig(context, parserConfig); + } + + public static FrameworkConfig generateConfig(SchemaPlus schema) { + SqlToRelConverter.Config converterConfig = SqlToRelConverter.config().withRelBuilderConfigTransform( + c -> c.withPushJoinCondition(false).withSimplify(false).withSimplifyValues(false).withBloat(-1) + .withDedupAggregateCalls(false).withPruneInputOfAggregate(false)) + .withDecorrelationEnabled(false).withExpand(false).withTrimUnusedFields(true); + var builderConfig = RelBuilder.Config.DEFAULT.withBloat(-1).withSimplify(false).withSimplifyValues(false); + return Frameworks.newConfigBuilder().defaultSchema(schema) + 
.parserConfig(SqlParser.Config.DEFAULT.withLex(Lex.MYSQL).withQuoting(Quoting.DOUBLE_QUOTE)) + .sqlToRelConverterConfig(converterConfig).context(Contexts.of(builderConfig)).build(); + } + + private static CalciteConnectionConfig connConfig(Context context, SqlParser.Config parserConfig) { + CalciteConnectionConfigImpl config = + context.maybeUnwrap(CalciteConnectionConfigImpl.class).orElse(CalciteConnectionConfig.DEFAULT); + if (!config.isSet(CalciteConnectionProperty.CASE_SENSITIVE)) { + config = config.set(CalciteConnectionProperty.CASE_SENSITIVE, String.valueOf(parserConfig.caseSensitive())); + } + if (!config.isSet(CalciteConnectionProperty.CONFORMANCE)) { + config = config.set(CalciteConnectionProperty.CONFORMANCE, String.valueOf(parserConfig.conformance())); + } + return config; + } + + private static SchemaPlus rootSchema(SchemaPlus schema) { + for (; ; ) { + SchemaPlus parentSchema = schema.getParentSchema(); + if (parentSchema == null) { + return schema; + } + schema = parentSchema; + } + } + + private void ready() { + RelDataTypeSystem typeSystem = connectionConfig.typeSystem(RelDataTypeSystem.class, RelDataTypeSystem.DEFAULT); + typeFactory = new JavaTypeFactoryImpl(typeSystem); + RelOptPlanner planner = this.planner = new VolcanoPlanner(costFactory, context); + RelOptUtil.registerDefaultRules(planner, connectionConfig.materializationsEnabled(), + Hook.ENABLE_BINDABLE.get(false)); + planner.setExecutor(executor); + + // If user specify own traitDef, instead of default default trait, + // register the trait def specified in traitDefs. 
+ if (this.traitDefs == null) { + planner.addRelTraitDef(ConventionTraitDef.INSTANCE); + if (CalciteSystemProperty.ENABLE_COLLATION_TRAIT.value()) { + planner.addRelTraitDef(RelCollationTraitDef.INSTANCE); + } + } else { + for (RelTraitDef def : this.traitDefs) { + planner.addRelTraitDef(def); + } + } + } + + public SqlNode parse(String sql) throws SqlParseException, ValidationException { + ready(); + Reader reader = new SourceStringReader(sql); + SqlParser parser = SqlParser.create(reader, parserConfig); + SqlNode sqlNode = parser.parseStmt(); + this.validator = createSqlValidator(createCatalogReader()); + try { + return validator.validate(sqlNode); + } catch (RuntimeException e) { + throw new ValidationException(e); + } + } + + private SqlValidator createSqlValidator(CalciteCatalogReader catalogReader) { + final SqlOperatorTable opTab = SqlOperatorTables.chain(operatorTable, catalogReader); + return new RawSqlValidator(opTab, catalogReader, getTypeFactory(), + sqlValidatorConfig.withDefaultNullCollation(connectionConfig.defaultNullCollation()) + .withLenientOperatorLookup(connectionConfig.lenientOperatorLookup()) + .withConformance(connectionConfig.conformance()).withIdentifierExpansion(true)); + } + + private CalciteCatalogReader createCatalogReader() { + SchemaPlus defaultSchema = requireNonNull(this.defaultSchema, "defaultSchema"); + final SchemaPlus rootSchema = rootSchema(defaultSchema); + + return new CalciteCatalogReader(CalciteSchema.from(rootSchema), CalciteSchema.from(defaultSchema).path(null), + getTypeFactory(), connectionConfig); + } + + public JavaTypeFactory getTypeFactory() { + return requireNonNull(typeFactory, "typeFactory"); + } + + public RelNode rel(SqlNode sqlNode) { + final RexBuilder rexBuilder = createRexBuilder(); + final RelOptCluster cluster = RelOptCluster.create(requireNonNull(planner, "planner"), rexBuilder); + final SqlToRelConverter sqlToRelConverter = + new SqlToRelConverter(this, validator, createCatalogReader(), cluster, 
convertletTable, + sqlToRelConverterConfig); + return sqlToRelConverter.convertQuery(sqlNode, false, true).project(); + } + + private RexBuilder createRexBuilder() { + return new RexBuilder(getTypeFactory()); + } + + @Override + public RelRoot expandView(RelDataType rowType, String queryString, List schemaPath, + @Nullable List viewPath) { + RelOptPlanner planner = this.planner; + if (planner == null) { + ready(); + planner = requireNonNull(this.planner, "planner"); + } + SqlParser parser = SqlParser.create(queryString, parserConfig); + SqlNode sqlNode; + try { + sqlNode = parser.parseQuery(); + } catch (SqlParseException e) { + throw new RuntimeException("parse failed", e); + } + + final CalciteCatalogReader catalogReader = createCatalogReader().withSchemaPath(schemaPath); + final SqlValidator validator = createSqlValidator(catalogReader); + + final RexBuilder rexBuilder = createRexBuilder(); + final RelOptCluster cluster = RelOptCluster.create(planner, rexBuilder); + final SqlToRelConverter sqlToRelConverter = + new SqlToRelConverter(this, validator, catalogReader, cluster, convertletTable, + sqlToRelConverterConfig); + + final RelRoot root = sqlToRelConverter.convertQuery(sqlNode, true, false); + final RelRoot root2 = root.withRel(sqlToRelConverter.flattenTypes(root.rel, true)); + final RelBuilder relBuilder = sqlToRelConverterConfig.getRelBuilderFactory().create(cluster, null); + return root2.withRel(RelDecorrelator.decorrelateQuery(root.rel, relBuilder)); + } +} + +class RawSqlValidator extends SqlValidatorImpl { + + RawSqlValidator(SqlOperatorTable opTab, CalciteCatalogReader catalogReader, JavaTypeFactory typeFactory, + Config config) { + super(opTab, catalogReader, typeFactory, config); + } + + @Override + protected RelDataType getLogicalSourceRowType(RelDataType sourceRowType, SqlInsert insert) { + final RelDataType superType = super.getLogicalSourceRowType(sourceRowType, insert); + return ((JavaTypeFactory) typeFactory).toSql(superType); + } + + @Override 
+ protected RelDataType getLogicalTargetRowType(RelDataType targetRowType, SqlInsert insert) { + final RelDataType superType = super.getLogicalTargetRowType(targetRowType, insert); + return ((JavaTypeFactory) typeFactory).toSql(superType); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RelFolder.java b/src/main/java/org/qed/RelFolder.java new file mode 100644 index 0000000..f7a1efe --- /dev/null +++ b/src/main/java/org/qed/RelFolder.java @@ -0,0 +1,26 @@ +package org.qed; + +import kala.collection.Seq; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; + +import java.util.function.Function; + +public interface RelFolder extends Function { + RelNode post(RelNode rel); + + default RelNode apply(RelNode rel) { + class RexFolder extends RexShuttle { + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + return super.visitSubQuery(subQuery.clone(RelFolder.this.apply(subQuery.rel))); + } + } + var newRel = rel.accept(new RexFolder()); + var inputs = Seq.from(newRel.getInputs()).map(this).asJava(); + return post(newRel.copy(newRel.getTraitSet(), inputs)); + } + +} diff --git a/src/main/java/org/qed/RelJSONShuttle.java b/src/main/java/org/qed/RelJSONShuttle.java new file mode 100644 index 0000000..1472f09 --- /dev/null +++ b/src/main/java/org/qed/RelJSONShuttle.java @@ -0,0 +1,364 @@ +package org.qed; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.*; +import kala.collection.Map; +import kala.collection.Seq; +import kala.collection.Set; +import kala.collection.immutable.ImmutableSeq; +import kala.control.Result; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.*; +import org.apache.calcite.rel.type.*; +import org.apache.calcite.rex.*; +import 
/**
 * Converts relational expressions to and from a JSON representation.
 *
 * <p>{@link #serialize(RelNode)} turns a plan into a JSON tree, {@link #deserialize} rebuilds a
 * plan on a {@code RuleBuilder} from such a tree (only a subset of node kinds is supported), and
 * the static {@code serializeToJson} / {@code deserializeFromJson} read and write whole files
 * consisting of table schemas, queries and human-readable "help" text.
 *
 * <p>NOTE(review): several declarations below read like generic signatures whose type arguments
 * are missing (e.g. {@code Result, String>}, {@code static T unwrap(Result res)},
 * {@code ImmutableSeq.>empty()}). The code is reproduced as found; confirm against the
 * authoritative source before compiling.
 *
 * @param env environment used to resolve tables and compute column offsets
 */
public record RelJSONShuttle(Env env) {
    private final static ObjectMapper mapper = new ObjectMapper();

    /** Wraps a sequence of nodes in a JSON array node. */
    private static ArrayNode array(Seq objs) {
        return new ArrayNode(mapper.getNodeFactory(), objs.asJava());
    }


    /** Reads {@code field} of {@code jsonNode} as an array; an error result if absent or not an array. */
    private static Result, String> array(JsonNode jsonNode, String field) {
        var arr = jsonNode.get(field);
        if (arr == null || !arr.isArray()) {
            return Result.err(String.format("Missing array field %s in:\n%s", field, jsonNode.toPrettyString()));
        }
        return Result.ok(ImmutableSeq.from(arr.elements()));
    }

    /** Wraps a map of named nodes in a JSON object node. */
    private static ObjectNode object(Map fields) {
        return new ObjectNode(mapper.getNodeFactory(), fields.asJava());
    }

    /** Reads {@code field} of {@code jsonNode}; an error result if the field is absent. */
    private static Result object(JsonNode jsonNode, String field) {
        var obj = jsonNode.get(field);
        if (obj == null) {
            return Result.err(String.format("Missing object field %s in:\n%s", field, jsonNode.toPrettyString()));
        }
        return Result.ok(obj);
    }

    /** Maps a Java boolean onto the shared JSON boolean nodes. */
    private static BooleanNode bool(boolean b) {
        return b ? BooleanNode.TRUE : BooleanNode.FALSE;
    }

    /** Returns the value of {@code res}, or throws an {@code Exception} carrying its error. */
    private static T unwrap(Result res) throws Exception {
        if (res.isErr()) {
            throw new Exception(res.getErr());
        }
        return res.get();
    }

    /** Manual entry point: deserializes one hard-coded rule file and prints each plan or the error. */
    public static void main(String[] args) throws IOException {
        var res = RelJSONShuttle.deserializeFromJson(Paths.get("ElevatedRules/filterProjectTranspose.json"));
        if (res.isErr()) {
            System.out.println(res.getErr());
        } else {
            res.get().forEach(r -> System.out.println(r.explain()));
        }
    }

    /**
     * Writes {@code relNodes} to {@code path} as pretty-printed JSON with three top-level fields:
     * {@code schemas}, {@code queries} and {@code help} (the {@code explain()} text of each plan).
     *
     * @throws IOException if the file cannot be written
     */
    public static void serializeToJson(List relNodes, Path path) throws IOException {
        var shuttle = new RelJSONShuttle(Env.empty());
        var helps = array(Seq.from(relNodes).map(rel -> new TextNode(rel.explain())));
        var queries = array(Seq.from(relNodes).map(shuttle::serialize));

        // The table list is read only after the queries have been serialized.
        // NOTE(review): presumably env.resolve(...) registers each scanned table as a side
        // effect, which would make this ordering significant — confirm in Env.
        var tables = shuttle.env.tables();
        var schemas = array(tables.map(table -> object(Map.of(
                "name", new TextNode(table.getName()),
                "fields", array(table.getColumnNames().map(TextNode::new)),
                "types", array(table.getColumnTypes().map(type -> new TextNode(type.toString()))),
                "nullable", array(table.getColumnTypes().map(RelDataType::isNullable).map(RelJSONShuttle::bool)),
                "key", array(Seq.from(table.getKeys().map(key -> array(Seq.from(key).map(IntNode::new))))),
                "guaranteed", array(table.getConstraints()
                        .map(check -> new RexJSONVisitor(shuttle.env.advanced(table.getColumnNames().size())).serialize(check)).toImmutableSeq())
        ))));

        var main = object(Map.of("schemas", schemas, "queries", queries, "help", helps));
        mapper.writerWithDefaultPrettyPrinter().writeValue(path.toFile(), main);
    }

    /**
     * Reads a file written in the format of {@code serializeToJson} and rebuilds its queries.
     *
     * <p>Tables are reconstructed from {@code name}, {@code fields}, {@code types},
     * {@code nullable} and {@code key}; the {@code guaranteed} constraints are not read back
     * (each table is created with an empty constraint set).
     *
     * @return the rebuilt plans, or an error message describing the first problem found
     * @throws IOException if the file cannot be read
     */
    public static Result, String> deserializeFromJson(Path path) throws IOException {
        var node = mapper.readTree(path.toFile());
        var env = Env.empty();
        var tables = array(node, "schemas").flatMap(schemas -> {
            var collected = ImmutableSeq.empty();
            for (var schema : schemas) {
                try {
                    var tys = unwrap(array(schema, "types"));
                    var nbs = unwrap(array(schema, "nullable"));
                    var nm = unwrap(object(schema, "name"));
                    var fds = unwrap(array(schema, "fields")).map(JsonNode::asText);
                    var kys = unwrap(array(schema, "key"));
                    var kgs = Set.from(kys.map(kg -> ImmutableBitSet.of(Seq.from(kg.elements()).map(JsonNode::asInt))));
                    if (tys.size() != nbs.size()) {
                        return Result.err("Expecting corresponding types and nullabilities");
                    }
                    var sts = tys.zip(nbs).map(tn -> (RelDataType) RelType.fromString(tn.component1().asText(),
                            tn.component2().asBoolean()));
                    collected = collected.appended(new QedTable(nm.asText(), fds, sts, kgs, Set.empty()));
                } catch (Exception e) {
                    return Result.err(
                            String.format("Broken table schemas: %s in\n%s", e.getMessage(), schema.toPrettyString()));
                }
            }
            return Result.ok(collected);
        });
        if (tables.isErr()) {
            return Result.err(tables.getErr());
        }
        env.tables().appendAll(tables.get());
        var queries = array(node, "queries");
        if (queries.isErr()) {
            return Result.err(queries.getErr());
        }
        var shuttle = new RelJSONShuttle(env);
        // Each query gets its own builder with all tables registered; the fold keeps the first
        // error or collects every built plan in order.
        return queries.get().map(q -> {
            var builder = RuleBuilder.create();
            tables.get().forEach(builder::addTable);
            return shuttle.deserialize(builder, q);
        }).foldLeft(Result.ok(ImmutableSeq.empty()), (qs, qb) -> qs.flatMap(s -> qb.map(b -> s.appended(b.build()))));
    }

    /**
     * Serializes one relational expression, recursively, into a single-key JSON object whose key
     * names the node kind ({@code scan}, {@code values}, {@code filter}, {@code project},
     * {@code join}, {@code correlate}, {@code distinct}, {@code union}, {@code intersect},
     * {@code except}, {@code sort}).
     *
     * @throws RuntimeException for node kinds without a case below, which includes
     *                          {@code INTERSECT ALL} and {@code EXCEPT ALL}
     */
    public JsonNode serialize(RelNode rel) {
        return switch (rel) {
            case TableScan scan ->
                    object(Map.of("scan", new IntNode(env.resolve(scan.getTable().unwrap(QedTable.class)))));
            case LogicalValues values -> {
                var visitor = new RexJSONVisitor(env);
                var schema = array(Seq.from(values.getRowType().getFieldList())
                        .map(field -> new TextNode(field.getType().toString())));
                var records = array(Seq.from(values.getTuples())
                        .map(tuple -> array(Seq.from(tuple).map(visitor::serialize))));
                yield object(Map.of("values", object(Map.of("schema", schema, "content", records))));
            }
            case LogicalFilter filter -> {
                // Expressions are serialized in an environment advanced by the input's field count.
                var visitor = new RexJSONVisitor(env.advanced(filter.getInput().getRowType().getFieldCount())
                        .recorded(filter.getVariablesSet()));
                yield object(Map.of("filter",
                        object(Map.of("condition", visitor.serialize(filter.getCondition()), "source",
                                serialize(filter.getInput())))));
            }
            case LogicalProject project -> {
                var visitor = new RexJSONVisitor(env.advanced(project.getInput().getRowType().getFieldCount())
                        .recorded(project.getVariablesSet()));
                var targets = array(Seq.from(project.getProjects()).map(visitor::serialize));
                yield object(
                        Map.of("project", object(Map.of("target", targets, "source", serialize(project.getInput())))));
            }
            case LogicalJoin join -> {
                var left = join.getLeft();
                var right = join.getRight();
                // The join condition sees the fields of both inputs.
                var visitor = new RexJSONVisitor(
                        env.advanced(left.getRowType().getFieldCount() + right.getRowType().getFieldCount())
                                .recorded(join.getVariablesSet()));
                yield object(Map.of("join",
                        object(Map.of("kind", new TextNode(join.getJoinType().toString()), "condition",
                                visitor.serialize(join.getCondition()), "left", serialize(left), "right",
                                serialize(right)))));
            }
            case LogicalCorrelate correlate -> {
                // The right side is serialized in an environment extended by the left side's fields.
                var rightShuttle = new RelJSONShuttle(env.advanced(correlate.getLeft().getRowType().getFieldCount())
                        .recorded(correlate.getVariablesSet()).advanced(0));
                yield object(Map.of("correlate",
                        array(Seq.of(serialize(correlate.getLeft()), rightShuttle.serialize(correlate.getRight())))));
            }
            case LogicalAggregate aggregate -> {
                // An aggregate is emitted as distinct(correlate(keys, aggregated)): `keys` projects
                // the group columns of the input, and `aggregated` applies the aggregate functions
                // to a second copy of the input filtered with "<=>" comparisons between each key
                // column and the corresponding input column.
                var groupCount = aggregate.getGroupCount();
                var level = env.base();
                var types = Seq.from(aggregate.getInput().getRowType().getFieldList())
                        .map(type -> new TextNode(type.getType().toString()));
                var keyCols = array(Seq.from(aggregate.getGroupSet())
                        .map(key -> object(Map.of("column", new IntNode(level + key), "type", types.get(key)))));
                var keys = object(Map.of("project",
                        object(Map.of("target", keyCols, "source", serialize(aggregate.getInput())))));
                var conditions = array(Seq.from(aggregate.getGroupSet()).mapIndexed((i, key) -> {
                    var type = types.get(key);
                    var leftCol = object(Map.of("column", new IntNode(level + i), "type", type));
                    var rightCol = object(Map.of("column", new IntNode(level + groupCount + key), "type", type));
                    return object(
                            Map.of("operator", new TextNode("<=>"), "operand", array(Seq.of(leftCol, rightCol)), "type",
                                    new TextNode("BOOLEAN")));
                }));
                var condition = object(Map.of("operator", new TextNode("AND"), "operand", conditions, "type",
                        new TextNode("BOOLEAN")));
                var aggs = array(Seq.from(aggregate.getAggCallList()).map(call -> object(
                        Map.of("operator", new TextNode(call.getAggregation().getName()), "operand",
                                array(Seq.from(call.getArgList()).map(target -> object(
                                        Map.of("column", new IntNode(level + groupCount + target), "type",
                                                types.get(target))))), "distinct", bool(call.isDistinct()),
                                "ignoreNulls", bool(call.ignoreNulls()), "type",
                                new TextNode(call.getType().toString())))));
                var aggregated = object(Map.of("aggregate", object(Map.of("function", aggs, "source",
                        object(Map.of("filter", object(Map.of("condition", condition, "source",
                                new RelJSONShuttle(env.lifted(groupCount)).serialize(aggregate.getInput())))))))));
                yield object(Map.of("distinct", object(Map.of("correlate", array(Seq.of(keys, aggregated))))));
            }
            case LogicalUnion union -> {
                // UNION ALL is a bare "union"; a distinct UNION wraps it in "distinct".
                var result = object(Map.of("union", array(Seq.from(union.getInputs()).map(this::serialize))));
                yield union.all ? result : object(Map.of("distinct", result));
            }
            case LogicalIntersect intersect when !intersect.all ->
                    object(Map.of("intersect", array(Seq.from(intersect.getInputs()).map(this::serialize))));
            case LogicalMinus minus when !minus.all ->
                    object(Map.of("except", array(Seq.from(minus.getInputs()).map(this::serialize))));
            case LogicalSort sort -> {
                var types = Seq.from(sort.getInput().getRowType().getFieldList())
                        .map(type -> new TextNode(type.getType().toString()));
                // Each collation entry is [field index, field type, collation short string].
                var collations = array(Seq.from(sort.collation.getFieldCollations()).map(collation -> {
                    var index = collation.getFieldIndex();
                    return array(Seq.of(new IntNode(index), types.get(index), new TextNode(collation.shortString())));
                }));
                var args = object(Map.of("collation", collations, "source", serialize(sort.getInput())));
                var visitor = new RexJSONVisitor(env.advanced(sort.getInput().getRowType().getFieldCount()));
                // "offset" and "limit" are only emitted when present on the sort.
                if (sort.offset != null) {
                    args.set("offset", visitor.serialize(sort.offset));
                }
                if (sort.fetch != null) {
                    args.set("limit", visitor.serialize(sort.fetch));
                }
                yield object(Map.of("sort", args));
            }
            default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName());
        };
    }

    /**
     * Rebuilds the plan described by {@code jsonNode} on top of {@code builder}.
     *
     * <p>Only the first field of {@code jsonNode} is inspected. Supported kinds are {@code scan},
     * {@code values}, {@code filter} and {@code project}; {@code join} and {@code correlate}
     * return a "Not implemented yet" error, anything else an "Unrecognized node" error.
     *
     * @return the builder with the plan pushed, or an error message
     */
    public Result deserialize(RuleBuilder builder, JsonNode jsonNode) {
        // NOTE(review): fields().next() is called without a hasNext() check — confirm callers
        // never pass an empty object.
        var entry = jsonNode.fields().next();
        var kind = entry.getKey();
        var content = entry.getValue();
        return switch (kind) {
            case String k when k.equals("scan") -> {
                // The payload is an index into env.tables().
                if (content.isInt() && 0 <= content.asInt() && content.asInt() < env.tables().size()) {
                    builder.scan(env.tables().get(content.asInt()).getName());
                    yield Result.ok(builder);
                }
                yield Result.err(String.format("Missing table with index %s", content.toPrettyString()));
            }
            case String k when k.equals("values") -> {
                try {
                    var et = unwrap(array(content, "schema"));
                    // Columns are named VALUES-0, VALUES-1, ... and are all created as nullable.
                    var rt = new RelRecordType(StructKind.FULLY_QUALIFIED, et.mapIndexed(
                            (i, t) -> (RelDataTypeField) new RelDataTypeFieldImpl(String.format("VALUES-%s", i), i,
                                    RelType.fromString(t.asText(), true))).asJava());
                    var vs = unwrap(array(content, "content"));
                    var vals = ImmutableSeq.>empty();
                    for (var v : vs) {
                        var val = ImmutableSeq.empty();
                        if (!v.isArray()) {
                            yield Result.err("Expecting tuple (JSON list) as value");
                        }
                        for (var jl : Seq.from(v.elements())) {
                            var l = unwrap(new RexJSONVisitor(env).deserialize(builder, jl));
                            if (l instanceof RexLiteral) {
                                val = val.appended((RexLiteral) l);
                            } else {
                                yield Result.err("Expecting literal expression");
                            }
                        }
                        vals = vals.appended(val.asJava());
                    }
                    builder.values(vals.asJava(), rt);
                    yield Result.ok(builder);
                } catch (Exception e) {
                    yield Result.err(e.getMessage());
                }
            }
            case String k when k.equals("filter") -> {
                try {
                    var cond = unwrap(object(content, "condition"));
                    var source = unwrap(object(content, "source"));
                    // The source is built first so the condition can refer to its fields.
                    var bs = unwrap(deserialize(builder, source));
                    var c = unwrap(new RexJSONVisitor(env).deserialize(builder, cond));
                    bs.filter(c);
                    yield Result.ok(bs);
                } catch (Exception e) {
                    yield Result.err(e.getMessage());
                }
            }
            case String k when k.equals("project") -> {
                try {
                    var target = unwrap(array(content, "target"));
                    var source = unwrap(object(content, "source"));
                    var bs = unwrap(deserialize(builder, source));
                    var ps = target.mapChecked(t -> unwrap(new RexJSONVisitor(env).deserialize(builder, t)));
                    bs.project(ps);
                    yield Result.ok(bs);
                } catch (Exception e) {
                    yield Result.err(e.getMessage());
                }
            }
            case String k when k.equals("join") -> Result.err("Not implemented yet");
            case String k when k.equals("correlate") -> Result.err("Not implemented yet");
            default -> Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString()));
        };
    }

    /**
     * Converts row expressions to and from JSON.
     *
     * @param env environment whose base offset is added to column indices on serialization
     */
    public record RexJSONVisitor(Env env) {
        /**
         * Serializes one row expression. Input references and field accesses become
         * {@code {"column", "type"}} objects; literals, calls and sub-queries become
         * {@code {"operator", "operand", "type"}} objects (sub-queries add {@code "query"}).
         *
         * @throws RuntimeException for expression kinds without a case below
         */
        public JsonNode serialize(RexNode rex) {
            return switch (rex) {
                case RexInputRef inputRef ->
                        object(Map.of("column", new IntNode(inputRef.getIndex() + env.base()), "type",
                                new TextNode(inputRef.getType().toString())));
                // A literal is an operator with no operands; a null value is written as "NULL".
                case RexLiteral literal -> object(Map.of("operator",
                        new TextNode(literal.getValue() == null ? "NULL" : literal.getValue().toString()), "operand",
                        array(Seq.empty()), "type", new TextNode(literal.getType().toString())));
                case RexSubQuery subQuery ->
                        object(Map.of("operator", new TextNode(subQuery.getOperator().toString()), "operand",
                                array(Seq.from(subQuery.getOperands()).map(this::serialize)), "query",
                                new RelJSONShuttle(env.advanced(0)).serialize(subQuery.rel), "type",
                                new TextNode(subQuery.getType().toString())));
                case RexCall call -> object(Map.of("operator", new TextNode(call.getOperator().toString()), "operand",
                        array(Seq.from(call.getOperands()).map(this::serialize)), "type",
                        new TextNode(call.getType().toString())));
                // A field access on a correlation variable is a column offset by what env resolves
                // for that variable's id.
                case RexFieldAccess fieldAccess -> object(Map.of("column", new IntNode(
                        fieldAccess.getField().getIndex() +
                                env.resolve(((RexCorrelVariable) fieldAccess.getReferenceExpr()).id)), "type",
                        new TextNode(fieldAccess.getType().toString())));
                default -> throw new RuntimeException("Not implemented: " + rex.getKind());
            };
        }

        /**
         * Rebuilds a row expression from JSON using {@code builder}.
         *
         * <p>An operator with no operands is parsed as a literal of the stated type. Otherwise the
         * operator name is looked up by reflection among the public static fields of
         * {@code SqlStdOperatorTable}; if none matches, a generic projection operator of that name
         * is used.
         *
         * @return the expression, or an error message
         */
        public Result deserialize(RuleBuilder builder, JsonNode jsonNode) {
            if (jsonNode.has("column") && jsonNode.get("column").isInt()) {
                // WARNING: THIS IS WRONG! NO ENVIRONMENT CONSIDERED!
                return Result.ok(builder.field(jsonNode.get("column").asInt()));
            } else if (jsonNode.has("operator") && jsonNode.get("operator").isTextual()) {
                var op = jsonNode.get("operator").asText();
                try {
                    var args = unwrap(array(jsonNode, "operand"));
                    var ty = RelType.fromString(unwrap(object(jsonNode, "type")).asText(), true);
                    if (args.isEmpty()) {
                        return Result.ok(RexLiteral.fromJdbcString(ty, ty.getSqlTypeName(), op));
                    } else {
                        var fields = args.mapChecked(expr -> unwrap(deserialize(builder, expr)));
                        for (var refl : Seq.from(SqlStdOperatorTable.class.getDeclaredFields())
                                .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) &&
                                        java.lang.reflect.Modifier.isStatic(f.getModifiers()))) {
                            var mist = refl.get(null);
                            if (mist instanceof SqlOperator sqlOperator && sqlOperator.getName().equals(op)) {
                                return Result.ok(builder.call(sqlOperator, fields));
                            }
                        }
                        return Result.ok(builder.call(builder.genericProjectionOp(op, ty), fields));
                    }
                } catch (Exception e) {
                    return Result.err(e.getMessage());
                }
            }
            return Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString()));
        }
    }
}
/**
 * Folder that replaces table scans with scans over narrowed tables containing only the columns
 * recorded in {@code usages}, re-indexing simple column projections that sit directly on a scan.
 *
 * <p>NOTE(review): the declarations of {@code cache}, {@code usages} and the constructor
 * parameter read like generic types whose arguments are missing ({@code MutableMap>},
 * {@code ImmutableMap>}). The code is reproduced as found; confirm against the authoritative
 * source before compiling.
 */
public class RelPruner implements RelFolder {
    // Narrowed table and row type per original table, so each table is pruned only once.
    private final MutableMap> cache = MutableMap.create();
    // Used column indices, looked up by the last component of a table's qualified name.
    ImmutableMap> usages;

    /**
     * @param usages used column indices per table name, in the shape produced by
     *               {@code RelScanner}
     */
    public RelPruner(ImmutableMap> usages) {
        this.usages = usages;
    }

    /** No post-processing: nodes rebuilt by the default folding are returned unchanged. */
    public RelNode post(RelNode rel) {
        return rel;
    }

    /**
     * Returns a scan over a copy of {@code scan}'s table restricted to its used columns (in
     * ascending index order). Keys are kept only when every key column survives, and are
     * re-indexed to the narrowed column positions.
     */
    private LogicalTableScan prune(LogicalTableScan scan) {
        var table = scan.getTable();
        QedTable cosTable;
        RelDataType rowType;
        if (!cache.containsKey(table)) {
            var fieldList = table.getRowType().getFieldList();
            var fields = usages.get(RelScanner.getName(scan)).toSeq().sorted();
            var columns = ImmutableMap.from(fields.map(table.getRowType().getFieldNames()::get)
                    .zip(fields.map(i -> fieldList.get(i).getType())));
            // A key qualifies when nothing is left after removing the used fields from it.
            var keys = table.getKeys() == null ? ImmutableSet.empty() : ImmutableSet.from(
                    Seq.from(table.getKeys()).filter(ks -> ImmutableSet.from(ks).removedAll(fields).isEmpty())
                            .map(ks -> ImmutableBitSet.of(ImmutableSet.from(ks).map(fields::indexOf))));
            var qName = table.getQualifiedName();
            // The narrowed table is created with an empty constraint set.
            cosTable = new QedTable(qName.get(qName.size() - 1), columns, keys, Set.empty());
            rowType = new RelRecordType(fields.map(fieldList::get).asJava());
            cache.put(table, Tuple.of(cosTable, rowType));
        } else {
            var p = cache.get(table);
            cosTable = p.component1();
            rowType = p.component2();
        }
        var t = RelOptTableImpl.create(table.getRelOptSchema(), rowType, table.getQualifiedName(), cosTable,
                table::getExpression);
        return LogicalTableScan.create(scan.getCluster(), t, scan.getHints());
    }

    /**
     * Rewrites {@code rel}. A projection made only of input references directly over a scan is
     * replaced by the pruned scan itself when it selects exactly the used columns in sorted order,
     * otherwise by the same projection re-indexed against the pruned scan. Any other node is
     * folded by {@link RelFolder#apply}.
     *
     * @throws IllegalStateException if a scan occurs outside such a projection and the recorded
     *                               usage for its table is missing or is not the full column list
     */
    @Override
    public RelNode apply(RelNode rel) {
        return switch (rel) {
            case LogicalProject project when (project.getInput() instanceof LogicalTableScan r) &&
                    Seq.from(project.getProjects()).allMatch(col -> col instanceof RexInputRef) -> {
                var fin = usages.get(RelScanner.getName(r)).toSeq().sorted();
                var ids = Seq.from(project.getProjects()).map(ref -> ((RexInputRef) ref).getIndex());
                var scan = prune(r);
                if (fin.sameElements(ids)) {
                    yield scan;
                }
                // Map each original column index to its position among the used columns.
                ImmutableSeq fields = ids.zip(project.getProjects())
                        .map(p -> new RexInputRef(fin.indexOf(p.component1()), p.component2().getType()));
                yield project.copy(project.getTraitSet(), scan, fields.asJava(), project.getRowType());
            }
            case LogicalTableScan r -> {
                if (!usages.containsKey(RelScanner.getName(r)) || !usages.get(RelScanner.getName(r)).toSeq().sorted()
                        .sameElements(ImmutableSeq.fill(r.getTable().getRowType().getFieldCount(), i -> i))) {
                    throw new IllegalStateException("Illegal raw occurrence of TableScan");
                }
                yield prune(r);
            }
            default -> RelFolder.super.apply(rel);
        };
    }
}

/**
 * Collects, per table name, the column indices that a plan reads from that table.
 *
 * <p>NOTE(review): the record component reads like a generic type whose arguments are missing
 * ({@code MutableMap>}); reproduced as found.
 *
 * @param usages accumulator filled by {@link #scan(RelNode)}
 */
record RelScanner(MutableMap> usages) {
    /** Creates a scanner with an empty accumulator. */
    public RelScanner() {
        this(new MutableHashMap<>());
    }

    /** Returns the last component of the scanned table's qualified name. */
    public static String getName(LogicalTableScan tableScan) {
        var qName = tableScan.getTable().getQualifiedName();
        return qName.get(qName.size() - 1);
    }

    /**
     * Records the columns used by {@code rel} and, recursively, by its inputs and by plans nested
     * in sub-query expressions.
     *
     * <p>A projection of plain input references directly over a scan adds just the referenced
     * indices to what is already recorded for the table; a scan reached any other way records
     * all of its columns, replacing the previous entry.
     */
    public void scan(RelNode rel) {
        // Visit sub-queries embedded in this node's expressions first.
        rel.accept(new RexScanner());
        switch (rel) {
            case LogicalProject project when (project.getInput() instanceof LogicalTableScan r) &&
                    Seq.from(project.getProjects()).allMatch(col -> col instanceof RexInputRef) -> {
                var ids = Seq.from(project.getProjects()).map(col -> ((RexInputRef) col).getIndex());
                var name = getName(r);
                usages.put(name, usages.getOrDefault(name, ImmutableSet.empty()).addedAll(ids));
            }
            case LogicalTableScan r -> {
                var ids = ImmutableSet.from(ImmutableSeq.fill(r.getTable().getRowType().getFieldCount(), i -> i));
                usages.put(getName(r), ids);
            }
            default -> rel.getInputs().forEach(this::scan);
        }
    }

    /** Expression visitor that scans the plan inside every sub-query it encounters. */
    private class RexScanner extends RexShuttle {
        @Override
        public RexNode visitSubQuery(RexSubQuery subQuery) {
            scan(subQuery.rel);
            return super.visitSubQuery(subQuery);
        }
    }
}
return scan(id, RexRN.varType(typeName, true), false); + } + + RelNode semantics(); + + default RexRN field(int ordinal) { + return new RexRN.Field(ordinal, this); + } + + default Seq fields(int... ordinals) { + return Seq.from(Arrays.stream(ordinals).iterator()).map(this::field); + } + + default Seq fields() { + return fields(IntStream.range(0, semantics().getRowType().getFieldCount()).toArray()); + } + + default RexRN joinField(int ordinal, RelRN right) { + return new RexRN.JoinField(ordinal, this, right); + } + + default Seq joinFields(RelRN right, int... ordinals) { + return Seq.from(Arrays.stream(ordinals).iterator()).map(i -> joinField(i, right)); + } + + default Seq joinFields(RelRN right) { + return joinFields(right, IntStream.range(0, + semantics().getRowType().getFieldCount() + right.semantics().getRowType().getFieldCount()).toArray()); + } + + default RexRN.Pred pred(SqlOperator op) { + return new RexRN.Pred(op, fields()); + } + + default RexRN.Pred pred(String name) { + return pred(RuleBuilder.create().genericPredicateOp(name, true)); + } + + default RexRN.Pred joinPred(SqlOperator op, RelRN right) { + return new RexRN.Pred(op, joinFields(right)); + } + + default RexRN.Pred joinPred(String name, RelRN right) { + return joinPred(RuleBuilder.create().genericPredicateOp(name, true), right); + } + + default RexRN.Proj proj(SqlOperator op) { + return new RexRN.Proj(op, fields()); + } + + default RexRN.Proj proj(String name, String type_name) { + return proj(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(type_name, true))); + } + + default Filter filter(RexRN cond) { + return new Filter(cond, this); + } + + default Filter filter(String name) { + return filter(pred(name)); + } + + default Project project(RexRN proj) { + return new Project(proj, this); + } + + default Project project(String name, String type_name) { + return project(proj(name, type_name)); + } + + default Join join(Join.JoinType ty, RexRN cond, RelRN right) { + return new 
Join(ty, cond, this, right); + } + + default Join join(JoinRelType ty, RexRN cond, RelRN right) { + return join(new Join.JoinType.ConcreteJoinType(ty), cond, right); + } + + default JoinWithSeparateConds joinWithSeparateConds(JoinRelType ty, RexRN cond, RelRN right) { + return new JoinWithSeparateConds(new Join.JoinType.ConcreteJoinType(ty), cond, this, right); + } + + default JoinWithPushedConds joinWithPushedConds(JoinRelType ty, RexRN cond, RelRN right) { + return new JoinWithPushedConds(new Join.JoinType.ConcreteJoinType(ty), cond, this, right); + } + + default Join join(JoinRelType ty, String name, RelRN right) {return join(ty, joinPred(name, right), right);} + + default Union union(boolean all, RelRN... sources) { + return new Union(all, Seq.of(this).appendedAll(sources)); + } + + default Intersect intersect(boolean all, RelRN... sources) { + return new Intersect(all, Seq.of(this).appendedAll(sources)); + } + + default Minus minus (boolean all, RelRN... sources) { + return new Minus(all, Seq.of(this).appendedAll(sources)); + } + + default Empty empty() { + return new Empty(this); + } + + + record Scan(String name, RelType.VarType ty, boolean unique) implements RelRN { + + @Override + public RelNode semantics() { + var table = new QedTable(name, Seq.of("col-" + name), Seq.of(ty), unique ? 
+ Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); + return RuleBuilder.create().addTable(table).scan(name).build(); + } + } + + record Filter(RexRN cond, RelRN source) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(source.semantics()).filter(cond.semantics()).build(); + } + } + + record Project(RexRN map, RelRN source) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(source.semantics()).project(map.semantics()).build(); + } + } + + record Join(Join.JoinType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty.semantics(), + cond.semantics()).build(); + } + + @Override + public RexRN field(int ordinal) { + return new RexRN.JoinField(ordinal, left, right); + } + + public interface JoinType { + JoinRelType semantics(); + + record ConcreteJoinType(JoinRelType type) implements JoinType { + @Override + public JoinRelType semantics() { + return type; + } + } + + record MetaJoinType(String name) implements JoinType { + @Override + public JoinRelType semantics() { + return JoinRelType.INNER; + } + } + } + } + + record JoinWithSeparateConds(Join.JoinType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty.semantics(), + cond.semantics()).build(); + } + + @Override + public RexRN field(int ordinal) { + return new RexRN.JoinField(ordinal, left, right); + } + + public interface JoinType { + JoinRelType semantics(); + + record ConcreteJoinType(JoinRelType type) implements JoinType { + @Override + public JoinRelType semantics() { + return type; + } + } + + record MetaJoinType(String name) implements JoinType { + @Override + public JoinRelType semantics() { + return JoinRelType.INNER; + } + } + } + } + + 
record JoinWithPushedConds(Join.JoinType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty.semantics(), + cond.semantics()).build(); + } + + @Override + public RexRN field(int ordinal) { + return new RexRN.JoinField(ordinal, left, right); + } + + public interface JoinType { + JoinRelType semantics(); + + record ConcreteJoinType(JoinRelType type) implements JoinType { + @Override + public JoinRelType semantics() { + return type; + } + } + + record MetaJoinType(String name) implements JoinType { + @Override + public JoinRelType semantics() { + return JoinRelType.INNER; + } + } + } + } + + record Union(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).union(all, sources.size()).build(); + } + } + + record Intersect(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).intersect(all, sources.size()).build(); + } + } + + record Minus(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).minus(all, sources.size()).build(); + } + } + + record Empty(RelRN sourceType) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().values(sourceType.semantics().getRowType()).build(); + } + } + + default Aggregate aggregate(RexRN groupName, AggCall aggCall) { + return new Aggregate(this, Seq.of(groupName), Seq.of(aggCall)); + } + + default Aggregate aggregate(String groupName, String aggName) { + return aggregate(groupBy(groupName), aggCall(aggName)); + } + + + default RexRN.GroupBy groupBy(String name) { + return new RexRN.GroupBy( + RuleBuilder.create().genericProjectionOp(name, new 
RelType.VarType(name + "_type", true)), + fields() + ); + } + + default RexRN.GroupBy groupBy(SqlOperator op) { + return new RexRN.GroupBy(op, fields()); + } + + default AggCall aggCall(String name) { + return new AggCall( + name, + RuleBuilder.create().genericAggregateOp(name, new RelType.VarType(name + "_type", true)), + false, + new RelType.VarType(name + "_type", true), + fields() + ); + } + + + record AggCall(String name, SqlAggFunction operator, boolean distinct, RelType type, Seq operands) { + public AggCall(String name, boolean distinct, RelType type, Seq operands){ + this(name, null, distinct, type, operands); + } + } + + record Aggregate(RelRN source, Seq groupSet, Seq aggCalls) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(source.semantics()); + + var groupKey = builder.groupKey(groupSet.map(RexRN::semantics)); + var calls = aggCalls.map(agg -> { + var aggFunc = agg.operator() != null ? agg.operator() : builder.genericAggregateOp(agg.name(), agg.type()); + return builder.aggregateCall(aggFunc, agg.distinct(), null, agg.name(), agg.operands().map(RexRN::semantics).asJava()); + }); + + return builder.aggregate(groupKey, calls).build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RelRacketShuttle.java b/src/main/java/org/qed/RelRacketShuttle.java new file mode 100644 index 0000000..1bf4b9f --- /dev/null +++ b/src/main/java/org/qed/RelRacketShuttle.java @@ -0,0 +1,161 @@ +package org.qed; + +import kala.collection.Seq; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.*; +import org.apache.calcite.rex.*; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.util.NlsString; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +public record RelRacketShuttle(Env env) { + public static 
void dumpToRacket(List relNodes, Path path) throws IOException { + assert relNodes.size() == 2; + var env = Env.empty(); + var rels = Seq.from(relNodes).mapIndexed((i, rel) -> SExpr.def("r" + i, new RelRacketShuttle(env).visit(rel))); + Seq tabs = env.tables().toSeq().map(t -> { + var table = t.unwrap(QedTable.class); + assert table != null; + var fullName = table.getName(); + Seq fields = table.getColumnNames().map(SExpr::string); + return SExpr.app("table-info", SExpr.string(fullName), SExpr.app("list", fields)); + }); + var tables = SExpr.def("tables", SExpr.app("list", tabs)); + var output = """ + #lang cosette + + %s + + %s + + (solve-relations r0 r1 tables println) + """.formatted(rels.joinToString("\n"), tables); + Files.write(path, output.getBytes()); + } + + public SExpr visit(RelNode rel) { + return switch (rel) { + case TableScan scan -> + SExpr.app("r-scan", SExpr.integer(env.resolve(scan.getTable().unwrap(QedTable.class)))); + case LogicalValues values -> { + var visitor = new RexRacketVisitor(env); + Seq tuples = Seq.from(values.getTuples()) + .map(tuple -> SExpr.app("list", Seq.from(tuple).map(visitor::visit))); + yield SExpr.app("r-values", SExpr.integer(values.getRowType().getFieldCount()), + SExpr.app("list", tuples)); + } + case LogicalFilter filter -> { + var condition = new RexRacketVisitor(env.advanced(filter.getInput().getRowType().getFieldCount()) + .recorded(filter.getVariablesSet())).visit(filter.getCondition()); + yield SExpr.app("r-filter", condition, visit(filter.getInput())); + } + case LogicalProject project -> { + var visitor = new RexRacketVisitor(env.advanced(project.getInput().getRowType().getFieldCount()) + .recorded(project.getVariablesSet())); + var targets = Seq.from(project.getProjects()).map(visitor::visit); + yield SExpr.app("r-project", SExpr.app("list", targets), visit(project.getInput())); + } + case LogicalJoin join -> { + var left = join.getLeft(); + var right = join.getRight(); + var visitor = new RexRacketVisitor( + 
env.advanced(left.getRowType().getFieldCount() + right.getRowType().getFieldCount()) + .recorded(join.getVariablesSet())); + var kind = SExpr.quoted(switch (join.getJoinType()) { + case INNER -> "inner"; + case LEFT -> "left"; + case RIGHT -> "right"; + case FULL -> "full"; + default -> + throw new UnsupportedOperationException("Not supported join type: " + join.getJoinType()); + }); + yield SExpr.app("r-join", kind, visitor.visit(join.getCondition()), visit(left), visit(right)); + } + case LogicalAggregate aggregate -> { + var aggs = SExpr.app("list", Seq.from(aggregate.getAggCallList()).map(agg -> { + var name = SExpr.quoted(switch (agg.getAggregation().getName().toLowerCase()) { + case "count" -> "aggr-count" + (agg.isDistinct() ? "-distinct" : "") + + (agg.ignoreNulls() ? "-all" : ""); + case "sum" -> "aggr-sum"; + case "max" -> "aggr-max"; + case "min" -> "aggr-min"; + case "avg" -> "aggr-avg"; + default -> throw new UnsupportedOperationException( + "Not supported aggregation function: " + agg.getAggregation().getName()); + }); + var cols = SExpr.app("list", Seq.from(agg.getArgList()).map(SExpr::integer)); + return SExpr.app("v-agg", name, cols); + })); + var groupSet = SExpr.app("list", Seq.from(aggregate.getGroupSet().asSet()).map(SExpr::integer)); + yield SExpr.app("r-agg", aggs, groupSet, visit(aggregate.getInput())); + } + case LogicalUnion union -> { + // TODO: Handle non-all union + yield SExpr.app("r-union", SExpr.app("list", Seq.from(union.getInputs()).map(this::visit))); + } + case LogicalSort sort -> { + // TODO: Properly handle sorting. 
+ yield this.visit(sort.getInput()); + } + default -> throw new UnsupportedOperationException("Not implemented: " + rel.getRelTypeName()); + }; + } + + public record RexRacketVisitor(Env env) { + public SExpr visit(RexNode rex) { + return switch (rex) { + case RexInputRef inputRef -> SExpr.app("v-var", SExpr.integer(env.base() + inputRef.getIndex())); + case RexLiteral literal -> { + var val = switch (literal.getValue()) { + case null -> SExpr.quoted("null"); + case Float v -> SExpr.real(v); + case Double v -> SExpr.real(v); + case BigDecimal v -> { + try { + yield SExpr.integer(v.longValueExact()); + } catch (ArithmeticException e) { + yield SExpr.real(v.doubleValue()); + } + } + case Integer v -> SExpr.integer(v); + case Long v -> SExpr.integer(v); + case Boolean v -> SExpr.bool(v); + case String v -> SExpr.string(v); + case NlsString v -> SExpr.string(v.toString()); + default -> throw new UnsupportedOperationException("Unsupported literal: " + literal); + }; + yield SExpr.app("v-op", val, SExpr.app("list")); + } + case RexSubQuery subQuery -> { + var name = switch (subQuery.getOperator().getKind()) { + case EXISTS -> "exists"; + case UNIQUE -> "unique"; + case IN -> "in"; + case SqlKind kind -> + throw new UnsupportedOperationException("Unsupported subquery operation: " + kind); + }; + var operands = Seq.from(subQuery.getOperands()).map(this::visit); + var rel = new RelRacketShuttle(env.advanced(0)).visit(subQuery.rel); + yield SExpr.app("v-hop", SExpr.quoted(name), SExpr.app("list", operands), rel); + } + case RexCall call -> { + var operands = Seq.from(call.getOperands()).map(this::visit); + yield SExpr.app("v-op", SExpr.quoted(call.op.getName().replace(" ", "-").toLowerCase()), + SExpr.app("list", operands)); + } + case RexFieldAccess access -> { + var id = SExpr.integer(env.resolve(((RexCorrelVariable) access.getReferenceExpr()).id) + + access.getField().getIndex()); + yield SExpr.app("v-var", id); + } + default -> throw new 
UnsupportedOperationException("Unsupported value: " + rex.getKind()); + }; + } + } +} diff --git a/src/main/java/org/qed/RelType.java b/src/main/java/org/qed/RelType.java new file mode 100644 index 0000000..d2b4328 --- /dev/null +++ b/src/main/java/org/qed/RelType.java @@ -0,0 +1,55 @@ +package org.qed; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeImpl; +import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.sql.type.BasicSqlType; +import org.apache.calcite.sql.type.SqlTypeName; + +public sealed interface RelType extends RelDataType { + static RelType fromString(String name, boolean nullable) { + for (var tn : SqlTypeName.values()) { + if (tn.getName().equals(name)) { + return new BaseType(tn, nullable); + } + } + return new VarType(name, nullable); + } + + final class VarType extends RelDataTypeImpl implements RelType { + private final String name; + private final boolean nullable; + + public VarType(String typeName, boolean nullability) { + name = typeName; + nullable = nullability; + computeDigest(); + } + + /* + * Notice: All virtual types will be translated to integer for prover + **/ + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + sb.append(name); + if (withDetail) { + sb.append(": ").append(nullable ? 
"nullable" : ""); + } + } + + @Override + public boolean isNullable() { + return nullable; + } + + public String getName() { + return "INTEGER"; + } + } + + final class BaseType extends BasicSqlType implements RelType { + public BaseType(SqlTypeName typeName, boolean nullable) { + super(RelDataTypeSystem.DEFAULT, typeName, nullable); + } + } +} diff --git a/src/main/java/org/qed/RexRN.java b/src/main/java/org/qed/RexRN.java new file mode 100644 index 0000000..060c50d --- /dev/null +++ b/src/main/java/org/qed/RexRN.java @@ -0,0 +1,157 @@ +package org.qed; + +import kala.collection.Seq; + +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.qed.RexRN.And; +import org.qed.RexRN.False; +import org.qed.RexRN.GroupBy; +import org.qed.RexRN.Pred; +import org.qed.RexRN.Proj; +import org.qed.RexRN.True; + +public interface RexRN { + + static RelType.VarType varType(String id, boolean nullable) { + return new RelType.VarType(id, nullable); + } + + static And and(RexRN... 
sources) { + return new And(Seq.from(sources)); + } + + static False falseLiteral() { + return new False(); + } + + static True trueLiteral() { + return new True(); + } + + RexNode semantics(); + + default Pred pred(SqlOperator op) { + return new Pred(op, Seq.of(this)); + } + + default Pred pred(String name) { + return pred(RuleBuilder.create().genericPredicateOp(name, true)); + } + + default Proj proj(SqlOperator op) { + return new Proj(op, Seq.of(this)); + } + + default Proj proj(String name, String type_name) { + return proj(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(type_name, true))); + } + + default GroupBy groupBy(SqlOperator op) { + return new GroupBy(op, Seq.of(this)); + } + + default GroupBy groupBy(String name) { + return groupBy(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(name + "_type", true))); + } + + default RelRN.AggCall aggCall(String name) { + return new RelRN.AggCall( + name, + RuleBuilder.create().genericAggregateOp(name, new RelType.VarType(name + "_type", true)), + false, + new RelType.VarType(name + "_type", true), + Seq.of(this) + ); + } + + + record Field(int ordinal, RelRN source) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().push(source.semantics()).field(ordinal); + } + } + + record JoinField(int ordinal, RelRN left, RelRN right) implements RexRN { + + @Override + public RexNode semantics() { + var leftCols = left.semantics().getRowType().getFieldCount(); + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).field(2, ordinal < leftCols ? + 0 : 1, ordinal < leftCols ? 
ordinal : ordinal - leftCols); + } + } + + record Pred(SqlOperator operator, Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + var builder = RuleBuilder.create(); + // builder.genericPredicateOp(name, nullable) + return builder.call(operator, sources.map(RexRN::semantics)); + } + } + + record Proj(SqlOperator operator, Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + var builder = RuleBuilder.create(); + // builder.genericProjectionOp(name, varType(type_name, nullable)) + return builder.call(operator, sources.map(RexRN::semantics)); + } + } + + record GroupBy(SqlOperator operator, Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + var builder = RuleBuilder.create(); + return builder.call(operator, sources.map(RexRN::semantics)); + } + } + + record And(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().and(sources.map(RexRN::semantics)); + } + } + + record False() implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().literal(false); + } + } + + record Not(RexRN source) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().not(source.semantics()); + } + } + + record Or(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().or(sources.map(RexRN::semantics)); + } + } + + record True() implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().literal(true); + } + } +} diff --git a/src/main/java/org/qed/RuleBuilder.java b/src/main/java/org/qed/RuleBuilder.java new file mode 100644 index 0000000..a48f869 --- /dev/null +++ b/src/main/java/org/qed/RuleBuilder.java @@ -0,0 +1,125 @@ +package org.qed; + +import kala.collection.Seq; +import kala.collection.Set; +import kala.tuple.Tuple; +import kala.tuple.Tuple2; +import kala.tuple.Tuple3; +import 
org.apache.calcite.plan.Context; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptSchema; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.*; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.Optionality; +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +public class RuleBuilder extends RelBuilder { + + private final AtomicInteger TABLE_ID_GENERATOR = new AtomicInteger(); + + private final SchemaPlus root; + + protected RuleBuilder(@Nullable Context context, RelOptCluster cluster, RelOptSchema relOptSchema, + SchemaPlus schema) { + super(context, cluster, relOptSchema); + root = schema; + } + + public static RuleBuilder create() { + var emptySchema = Frameworks.createRootSchema(true); + var config = Frameworks.newConfigBuilder().defaultSchema(emptySchema).build(); + return Frameworks.withPrepare(config, + (cluster, relOptSchema, rootSchema, statement) -> new RuleBuilder(config.getContext(), cluster, + relOptSchema, emptySchema)); + } + + public RuleBuilder addTable(QedTable table) { + root.add(table.getName(), table); + return this; + } + + /** + * Create a qed table given the column types and whether they are unique (i.e. 
can be key) + * + * @param schema the list of column types and they are unique + * @return the table created from the given schema + */ + public QedTable createQedTable(Seq> schema) { + var identifier = "Table_" + TABLE_ID_GENERATOR.getAndIncrement(); + var cols = schema.mapIndexed( + (idx, tuple) -> Tuple.of(identifier + "_Column_" + idx, tuple.component1(), tuple.component2())); + return new QedTable(identifier, + cols.map(tuple -> Map.entry(tuple.component1(), tuple.component2())).toImmutableMap(), + Set.from(cols.filter(Tuple3::component3).map(tuple -> Set.of(tuple.component1()))), Set.of()); + } + + /** + * Create one single-column table per type id, register each with this builder, and return the table names. + * + * @param typeIds the absolute value is the type id (column type "Type_N"); a negative value marks the table's column as unique + * @return the names for the created tables + */ + public Seq sourceSimpleTables(Seq typeIds) { + return typeIds.map(id -> { + var identifier = "Table_" + TABLE_ID_GENERATOR.getAndIncrement(); + var colName = identifier + "_Column"; + var colType = new RelType.VarType("Type_" + (id < 0 ? -id : id), true); + var table = new QedTable(identifier, kala.collection.Map.of(colName, colType), + id < 0 ?
Set.of(Set.of(colName)) : Set.of(), Set.empty()); + addTable(table); + return table.getName(); + }); + } + + public Seq joinFields() { + return Seq.from(fields(2, 0)).concat(fields(2, 1)); + } + + + public SqlAggFunction genericAggregateOp(String name, RelType aggregation) { + return new QedAggregateFunction(name, aggregation); + } + + public SqlOperator genericPredicateOp(String name, boolean nullable) { + return new QedFunction(name, new RelType.BaseType(SqlTypeName.BOOLEAN, nullable)); + } + + public SqlOperator genericProjectionOp(String name, RelType projection) { + return new QedFunction(name, projection); + } + + public static class QedFunction extends SqlFunction { + + private final RelType codomain; + + public QedFunction(String name, RelType returnType) { + super(name, SqlKind.OTHER_FUNCTION, opBinding -> { + var factory = opBinding.getTypeFactory(); + return factory.createTypeWithNullability(returnType, returnType.isNullable()); + }, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION); + codomain = returnType; + } + + public RelType getReturnType() { + return codomain; + } + } + + public static class QedAggregateFunction extends SqlAggFunction { + + public QedAggregateFunction(String name, RelType returnType) { + super(name, null, SqlKind.OTHER_FUNCTION, opBinding -> { + var factory = opBinding.getTypeFactory(); + return factory.createTypeWithNullability(returnType, returnType.isNullable()); + }, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION, false, false, Optionality.OPTIONAL); + } + } + +} \ No newline at end of file diff --git a/src/main/java/org/qed/SExpr.java b/src/main/java/org/qed/SExpr.java new file mode 100644 index 0000000..d530b76 --- /dev/null +++ b/src/main/java/org/qed/SExpr.java @@ -0,0 +1,88 @@ +package org.qed; + +import kala.collection.Seq; + +public sealed interface SExpr { + static Lst list(SExpr... 
elems) { + return new Lst(Seq.of(elems)); + } + + static Sym symbol(String name) { + return new Sym(name); + } + + static Str string(String value) { + return new Str(value); + } + + static Bool bool(boolean value) { + return new Bool(value); + } + + static Int integer(long value) { + return new Int(value); + } + + static Real real(double value) { + return new Real(value); + } + + /** Builds the application {@code (fn args...)}. */ + static Lst app(String fn, SExpr... args) { + return app(fn, Seq.of(args)); + } + + /** Builds the application {@code (fn args...)} by prepending the symbol {@code fn} to {@code args}. */ + static Lst app(String fn, Seq args) { + return new Lst(args.prepended(symbol(fn))); + } + + /** Returns the symbol made of a quote mark followed by {@code sym}. */ + static Sym quoted(String sym) { + return symbol("'" + sym); + } + + /** Builds {@code (define sym expr)}. */ + static Lst def(String sym, SExpr expr) { + return SExpr.list(SExpr.symbol("define"), SExpr.symbol(sym), expr); + } + + /** A list node; rendered as its elements joined by spaces inside parentheses. */ + record Lst(Seq nodes) implements SExpr { + @Override + public String toString() { + return nodes.map(Object::toString).joinToString(" ", "(", ")"); + } + } + + /** A bare symbol; rendered verbatim. */ + record Sym(String name) implements SExpr { + @Override + public String toString() { + return name; + } + } + + /** A string literal; rendered in double quotes with embedded backslashes and double quotes escaped. */ + record Str(String value) implements SExpr { + @Override + public String toString() { + /* Escape backslashes first so the ones added for quotes are not doubled; String.valueOf keeps the old "null" rendering for a null value. */ + return "\"" + String.valueOf(value).replace("\\", "\\\\").replace("\"", "\\\"") + "\""; + } + } + + /** A boolean literal. */ + record Bool(boolean value) implements SExpr { + @Override + public String toString() { + return value ?
"#t" : "#f"; + } + } + + record Int(long value) implements SExpr { + @Override + public String toString() { + return Long.toString(value); + } + } + + record Real(double value) implements SExpr { + @Override + public String toString() { + return Double.toString(value); + } + } +} diff --git a/src/main/java/org/qed/SQLJSONParser.java b/src/main/java/org/qed/SQLJSONParser.java new file mode 100644 index 0000000..2788c3b --- /dev/null +++ b/src/main/java/org/qed/SQLJSONParser.java @@ -0,0 +1,63 @@ +package org.qed; + +import com.fasterxml.jackson.databind.ObjectMapper; +import kala.collection.mutable.MutableArrayList; +import kala.collection.mutable.MutableList; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql2rel.RelFieldTrimmer; +import org.apache.calcite.tools.RelBuilder; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Paths; +import java.util.List; + +/** + * A SQLParse instance can parse DDL statements and valid DML statements into JSON format. + */ +public class SQLJSONParser { + + private final MutableList relNodes; + + /** + * Create a new instance with no RelNodes. + */ + public SQLJSONParser() { + relNodes = new MutableArrayList<>(); + } + + /** + * Create a new instance with the list of RelNodes within. + */ + public SQLJSONParser(List nodes) { + relNodes = MutableArrayList.from(nodes); + } + + /** + * Parse a DML statement with current schema. + * + * @param dml The DML statement to be parsed. + */ + public void parseDML(SchemaPlus context, String dml) throws Exception { + RawPlanner planner = new RawPlanner(context); + relNodes.append(planner.rel(planner.parse(dml))); + } + + /** + * Dump the parsed statements to a file. + * + * @param path The given file. 
+ */ + public void dumpOutput(RelBuilder builder, String path) throws IOException { + var trimmer = new RelFieldTrimmer(null, builder); + var nodes = relNodes.map(trimmer::trim); + var scanner = new RelScanner(); + nodes.forEach(scanner::scan); + var pruner = new RelPruner(scanner.usages().toImmutableMap()); + var rNodes = nodes.map(pruner); + new ObjectMapper().writerWithDefaultPrettyPrinter() + .writeValue(new File(path + ".json"), JSONSerializer.serialize(rNodes)); + RelRacketShuttle.dumpToRacket(rNodes.asJava(), Paths.get(path + ".rkt")); + } +} diff --git a/src/main/java/org/qed/SchemaGenerator.java b/src/main/java/org/qed/SchemaGenerator.java new file mode 100644 index 0000000..484b1ef --- /dev/null +++ b/src/main/java/org/qed/SchemaGenerator.java @@ -0,0 +1,276 @@ +package org.qed; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import kala.collection.immutable.ImmutableSet; +import kala.collection.mutable.MutableHashMap; +import kala.collection.mutable.MutableList; +import kala.collection.mutable.MutableMap; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.util.Quoting; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.config.Lex; +import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.type.RelDataTypeImpl; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.schema.*; +import org.apache.calcite.schema.impl.*; +import org.apache.calcite.sql.*; +import org.apache.calcite.sql.ddl.*; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; + 
+import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * A SchemaGenerator instance can execute DDL statements and generate schemas in the process. + */ +public class SchemaGenerator { + + private static final Map> toPrimitive = + ImmutableMap.>builder().put("BINARY", String.class).put("CHAR", String.class) + .put("VARBINARY", String.class).put("VARCHAR", String.class).put("BLOB", String.class) + .put("TINYBLOB", String.class).put("MEDIUMBLOB", String.class).put("LONGBLOB", String.class) + .put("TEXT", String.class).put("TINYTEXT", String.class).put("MEDIUMTEXT", String.class) + .put("LONGTEXT", String.class).put("ENUM", String.class).put("SET", String.class) + .put("BOOL", boolean.class).put("BOOLEAN", boolean.class).put("DEC", double.class) + .put("DECIMAL", double.class).put("DOUBLE", double.class).put("DOUBLE PRECISION", double.class) + .put("FLOAT", float.class).put("DATE", int.class).put("DATETIME", int.class) + .put("TIMESTAMP", int.class).put("TIME", int.class).put("YEAR", int.class).put("INT", int.class) + .put("TINYINT", int.class).put("SMALLINT", int.class).put("MEDIUMINT", int.class) + .put("BIGINT", int.class).put("INTEGER", int.class).build(); + private static final Pattern functionPattern = Pattern.compile( + "(?i)DECLARE\\s+(?SCALAR|AGGREGATE)\\s+FUNCTION\\s+(?\\w+)\\s*\\((?.*)\\)" + + "\\s+RETURNS\\s+(?.+)"); + private static final SqlParser.Config schemaParserConfig = + SqlParser.Config.DEFAULT.withParserFactory(SqlDdlParserImpl.FACTORY).withLex(Lex.MYSQL) + .withQuoting(Quoting.DOUBLE_QUOTE); + private final QedSchema schema; + private final Map declaredFunctions = new HashMap<>(); + + /** + * Create a SchemaGenerator instance by setting up a connection to JDBC. 
+ */ + public SchemaGenerator() { + schema = new QedSchema(this); + } + + /** + * Execute a CREATE statement. + * + * @param create The given CREATE statement. + */ + public void applyCreate(String create) throws SQLException { + Pattern supported = Pattern.compile("(?i)CREATE\\s+(VIEW|TABLE)"); + if (!supported.matcher(create).find()) { + // TODO: Improve error handling + return; + } + SqlParser schemaParser = SqlParser.create(create, schemaParserConfig); + SqlNode schemaNode; + try { + schemaNode = schemaParser.parseStmt(); + } catch (Exception e) { + System.err.println("Warning: Skipping problematic statement:\n" + create); + System.err.println(e + "\n"); + return; + } + switch (schemaNode) { + case SqlCreateTable sqlCreateTable -> schema.addTable(sqlCreateTable); + case SqlCreateView sqlCreateView -> { + try { + schema.addView(sqlCreateView, create); + } catch (Exception e) { + System.err.println("Warning: Encountered problematic view definition:\n" + create); + System.err.println(e.getCause() + "\n"); + } + } + default -> throw new RuntimeException("Unsupported create statement:\n" + create); + } + } + + /** + * Execute a DECLARE FUNCTION statement. + * + * @param declareFunction The given DECLARE FUNCTION statement. 
+ */ + public void applyDeclareFunction(String declareFunction) throws Exception { + Matcher matcher = functionPattern.matcher(declareFunction); + if (!matcher.find()) { + throw new RuntimeException("Broken function declaration:\n" + declareFunction); + } + String identifier = matcher.group("identifier"); + String[] source = matcher.group("source").split(","); + String target = matcher.group("target").split("\\(")[0].trim().toUpperCase(); + Class[] parameters = new Class[source.length]; + if (!toPrimitive.containsKey(target)) { + throw new RuntimeException("Invalid return type: " + target); + } + for (int i = 0; i < source.length; i += 1) { + String arg = source[i].split("\\(")[0].trim().toUpperCase(); + if (!toPrimitive.containsKey(arg)) { + throw new RuntimeException("Invalid argument type: " + arg); + } + parameters[i] = toPrimitive.get(arg); + } + Function customFunction; + Constructor methodConstructor = + Method.class.getDeclaredConstructor(Class.class, String.class, Class[].class, Class.class, + Class[].class, int.class, int.class, String.class, byte[].class, byte[].class, byte[].class); + methodConstructor.setAccessible(true); + if (matcher.group("type").equalsIgnoreCase("SCALAR")) { + Method scalarFunction = methodConstructor.newInstance(SchemaGenerator.class, "qedFunction", parameters, + toPrimitive.get(target), null, 0, 0, "", null, null, null); + customFunction = ScalarFunctionImpl.createUnsafe(scalarFunction); + } else { + ReflectiveFunctionBase.ParameterListBuilder sourceParameters = ReflectiveFunctionBase.builder(); + ImmutableList.Builder> sourceTypes = ImmutableList.builder(); + for (Class clazz : parameters) { + sourceParameters.add(clazz, clazz.getName(), false); + sourceTypes.add(clazz); + } + Method nullFunction = methodConstructor.newInstance(SchemaGenerator.class, "qedFunction", parameters, + toPrimitive.get(target), null, 0, 0, "", null, null, null); + Constructor aggregateFunctionConstructor = + 
AggregateFunctionImpl.class.getDeclaredConstructor(Class.class, List.class, List.class, Class.class, + Class.class, Method.class, Method.class, Method.class, Method.class); + aggregateFunctionConstructor.setAccessible(true); + customFunction = aggregateFunctionConstructor.newInstance(SchemaGenerator.class, sourceParameters.build(), + sourceTypes.build(), toPrimitive.get(target), toPrimitive.get(target), nullFunction, nullFunction, + null, null); + } + declaredFunctions.put(identifier, customFunction); + } + + /** + * @return The current schema. + */ + public SchemaPlus extractSchema() { + return schema.plus(); + } + + /** + * @return The declared custom store. + */ + public Map customFunctions() { + return declaredFunctions; + } + +} + +class QedSchema extends AbstractSchema { + + final MutableMap tables = new MutableHashMap<>(); + final SchemaGenerator owner; + + public QedSchema(SchemaGenerator source) { + owner = source; + } + + public void addTable(SqlCreateTable createTable) { + if (createTable.columnList == null) { + throw new RuntimeException("No column in table " + createTable.name); + } + var planner = new RawPlanner(this.plus()); + var names = MutableList.create(); + var types = MutableList.create(); + var nullabilities = MutableList.create(); + var keys = MutableList.create(); + var checkConstraints = MutableList.create(); + for (SqlNode column : createTable.columnList) { + switch (column.getKind()) { + case CHECK -> { + var check = (SqlBasicCall) ((SqlCheckConstraint) column).getOperandList().get(1); + var wrapper = new SqlSelect(SqlParserPos.ZERO, SqlNodeList.EMPTY, SqlNodeList.SINGLETON_STAR, + createTable.name, check, null, null, SqlNodeList.EMPTY, null, null, null, null); + try { + var filter = (LogicalFilter) planner.rel(wrapper).getInput(0); + checkConstraints.append(filter.getCondition()); + } catch (Exception ignore) { + } + } + case COLUMN_DECL -> { + SqlColumnDeclaration decl = (SqlColumnDeclaration) column; + 
names.append(decl.name.toString()); + types.append(SqlTypeName.get(decl.dataType.getTypeName().toString())); + nullabilities.append(decl.strategy != ColumnStrategy.NOT_NULLABLE); + } + case FOREIGN_KEY -> System.err.println("Foreign key constraint is not implemented in qed yet."); + case PRIMARY_KEY, UNIQUE -> { + SqlKeyConstraint cons = (SqlKeyConstraint) column; + List key = new ArrayList<>(); + for (SqlNode id : (SqlNodeList) cons.getOperandList().get(1)) { + int index = names.indexOf(id.toString()); + key.add(index); + if (column.getKind() == SqlKind.PRIMARY_KEY) { + nullabilities.set(index, false); + } + } + keys.append(ImmutableBitSet.of(key)); + } + default -> throw new RuntimeException( + "Unsupported declaration type " + column.getKind() + " in table " + createTable.name); + } + } + var qedTable = new QedTable(createTable.name.toString(), names.zip( + types.zip(nullabilities).map(type -> new RelType.BaseType(type.component1(), + type.component2()))) + .toImmutableMap(), ImmutableSet.from(keys), ImmutableSet.from(checkConstraints)); + tables.put(createTable.name.toString(), qedTable); + } + + public void addView(SqlCreateView sqlCreateView, String rawDef) throws SQLException { + if (sqlCreateView.columnList == null || sqlCreateView.columnList.getList().isEmpty()) { + throw new RuntimeException("No field definition in view " + sqlCreateView.name); + } + // Some regex hackery to extract the raw definition... 
+ var matcher = Pattern.compile("(?s).*?\\(.*?\\)\\s+[Aa][Ss](.*)").matcher(rawDef); + if (!matcher.find()) { + throw new RuntimeException("Cannot extract definition of view " + sqlCreateView.name); + } + var rawQuery = matcher.group(1); + String fields = sqlCreateView.columnList.getList().stream().filter(Objects::nonNull).map(SqlNode::toString) + .collect(Collectors.joining("\", \"")); + String wrapper = "SELECT * FROM (%s) AS \"_\" (\"%s\")".formatted(rawQuery, fields); + Properties info = new Properties(); + info.setProperty(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), "FALSE"); + CalciteConnection connection = + DriverManager.getConnection("jdbc:calcite:", info).unwrap(CalciteConnection.class); + CalciteSchema calciteSchema = CalciteSchema.from(plus()); + CalcitePrepare.AnalyzeViewResult parsed = + Schemas.analyzeView(connection, calciteSchema, null, wrapper, null, false); + JavaTypeFactory typeFactory = (JavaTypeFactory) parsed.typeFactory; + Type elementType = typeFactory.getJavaClass(parsed.rowType); + Table viewTable = + new ViewTable(elementType, RelDataTypeImpl.proto(parsed.rowType), wrapper, calciteSchema.path(null), + null); + tables.put(sqlCreateView.name.toString(), viewTable); + } + + protected Map getTableMap() { + return tables.asJava(); + } + + public SchemaPlus plus() { + SchemaPlus plus = CalciteSchema.createRootSchema(true, false, "Qed", this).plus(); + for (String fn : owner.customFunctions().keySet()) { + plus.add(fn, owner.customFunctions().get(fn)); + } + return plus; + } + +} diff --git a/src/main/java/org/qed/UnprovableRRuleInstances/JoinAssociate.java b/src/main/java/org/qed/UnprovableRRuleInstances/JoinAssociate.java new file mode 100644 index 0000000..c0c08b7 --- /dev/null +++ b/src/main/java/org/qed/UnprovableRRuleInstances/JoinAssociate.java @@ -0,0 +1,69 @@ +package org.qed.UnprovableRRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import 
org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinAssociate() implements RRule.RRuleFamily { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); + static final String pred_ab = "pred_ab"; + static final String pred_bc = "pred_bc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } + + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var 
joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); + } + + @Override + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); + } +} diff --git a/src/main/java/org/qed/UnprovableRRuleInstances/PruneLeftEmptyJoin.java b/src/main/java/org/qed/UnprovableRRuleInstances/PruneLeftEmptyJoin.java new file mode 100644 index 0000000..d6af2aa --- /dev/null +++ b/src/main/java/org/qed/UnprovableRRuleInstances/PruneLeftEmptyJoin.java @@ -0,0 +1,20 @@ +package org.qed.UnprovableRRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; + +public record PruneLeftEmptyJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.empty().join(JoinRelType.RIGHT, "pred", right); + } + + @Override + public RelRN after() { + return right; + } +} diff --git a/src/main/java/org/qed/UnprovableRRuleInstances/PruneRightEmptyJoin.java b/src/main/java/org/qed/UnprovableRRuleInstances/PruneRightEmptyJoin.java new file mode 100644 index 0000000..46527f7 --- /dev/null +++ b/src/main/java/org/qed/UnprovableRRuleInstances/PruneRightEmptyJoin.java @@ -0,0 +1,22 @@ +package org.qed.UnprovableRRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneRightEmptyJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right.empty()); + + 
@Override + public RelRN before() { + return left.join(JoinRelType.LEFT, "pred", right.empty()); + } + + @Override + public RelRN after() { + return left; + } +} diff --git a/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinJoinTranspose.java b/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinJoinTranspose.java new file mode 100644 index 0000000..3ecc99a --- /dev/null +++ b/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinJoinTranspose.java @@ -0,0 +1,25 @@ +package org.qed.UnprovableRRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + + +public record SemiJoinJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "left_Type"); + static final RelRN middle = RelRN.scan("middle", "middle_Type"); + static final RelRN right = RelRN.scan("right", "right_Type"); + static final RexRN semiCond = left.joinPred("semi", middle); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).join(JoinRelType.SEMI, semiCond, middle); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, semiCond, middle).join(JoinRelType.INNER, joinCond, right); + } +} diff --git a/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinProjectTranspose.java b/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinProjectTranspose.java new file mode 100644 index 0000000..357aef7 --- /dev/null +++ b/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinProjectTranspose.java @@ -0,0 +1,23 @@ +package org.qed.UnprovableRRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "left_type"); + static final RelRN right = RelRN.scan("Right", "right_type"); + static final 
RexRN proj = left.proj("proj", "proj_type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } +} diff --git a/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinRemove.java b/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinRemove.java new file mode 100644 index 0000000..27af3ba --- /dev/null +++ b/src/main/java/org/qed/UnprovableRRuleInstances/SemiJoinRemove.java @@ -0,0 +1,21 @@ +package org.qed.UnprovableRRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } +} From 1284b73af777d37bf953877de5180972f950e561 Mon Sep 17 00:00:00 2001 From: Wesley Zheng Date: Tue, 5 May 2026 18:10:28 -0700 Subject: [PATCH 2/5] Adding workflows back --- .github/workflows/codegen-test.yml | 42 ++++++++++++++++++ .github/workflows/prover-test.yml | 58 +++++++++++++++++++++++++ .gitignore | 2 + src/main/java/org/qed/ProjectPaths.java | 19 -------- 4 files changed, 102 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/codegen-test.yml create mode 100644 .github/workflows/prover-test.yml delete mode 100644 src/main/java/org/qed/ProjectPaths.java diff --git a/.github/workflows/codegen-test.yml b/.github/workflows/codegen-test.yml new file mode 100644 index 0000000..3d76901 --- /dev/null +++ b/.github/workflows/codegen-test.yml @@ -0,0 +1,42 @@ +name: Test Code Generation + +on: + push: + pull_request: + +jobs: 
+ test-code-generation: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '25' + distribution: 'temurin' + + - name: Cache Maven dependencies + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build project + run: mvn -B compile --file pom.xml + + - name: Run Calcite tests + run: | + chmod +x scripts/test-codegen.sh + ./scripts/test-codegen.sh + + - name: Upload generated code + if: always() + uses: actions/upload-artifact@v4 + with: + name: generated-code + path: | + src/main/java/org/qed/Backends/Calcite/Generated/*.java \ No newline at end of file diff --git a/.github/workflows/prover-test.yml b/.github/workflows/prover-test.yml new file mode 100644 index 0000000..3b240c6 --- /dev/null +++ b/.github/workflows/prover-test.yml @@ -0,0 +1,58 @@ +name: Test Provability + +on: + push: + pull_request: + +jobs: + test-provability: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '25' + distribution: 'temurin' + + - name: Cache Maven dependencies + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build project + run: mvn -B compile --file pom.xml + + - name: Generate JSON for all rules + run: | + mkdir -p tmp-rules + mvn dependency:resolve + chmod +x scripts/generate-rule-json.sh + ./scripts/generate-rule-json.sh + + - name: Install dependencies + run: | + chmod +x scripts/install-dependencies.sh + ./scripts/install-dependencies.sh + + - name: Build qed-prover + run: | + chmod +x scripts/build-qed-prover.sh + ./scripts/build-qed-prover.sh + + - name: Test all rules + run: | + chmod +x scripts/test-rules.sh + ./scripts/test-rules.sh + + - 
name: Upload test artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: results + path: tmp-rules/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index eb0115b..c9980ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ * +!.github +!.github/** !.envrc !.gitignore !flake.nix diff --git a/src/main/java/org/qed/ProjectPaths.java b/src/main/java/org/qed/ProjectPaths.java deleted file mode 100644 index 642518f..0000000 --- a/src/main/java/org/qed/ProjectPaths.java +++ /dev/null @@ -1,19 +0,0 @@ -package org.qed; - -import java.nio.file.Path; - -/** - * Resolves repo-relative paths for codegen and tests. Maven sets {@code -Drulescript.basedir}; - * otherwise {@code user.dir} is used (run from repository root). - */ -public final class ProjectPaths { - private ProjectPaths() {} - - public static Path baseDir() { - String override = System.getProperty("rulescript.basedir"); - if (override != null && !override.isBlank()) { - return Path.of(override); - } - return Path.of(System.getProperty("user.dir")); - } -} From 64b5f602398f278d5464fe99f5d2396b5874b44c Mon Sep 17 00:00:00 2001 From: Wesley Zheng Date: Tue, 5 May 2026 18:11:42 -0700 Subject: [PATCH 3/5] Create ProjectPaths.java --- src/main/java/org/qed/ProjectPaths.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 src/main/java/org/qed/ProjectPaths.java diff --git a/src/main/java/org/qed/ProjectPaths.java b/src/main/java/org/qed/ProjectPaths.java new file mode 100644 index 0000000..efcfb2e --- /dev/null +++ b/src/main/java/org/qed/ProjectPaths.java @@ -0,0 +1,19 @@ +package org.qed; + +import java.nio.file.Path; + +/** + * Repo-relative paths for codegen and tests. Defaults to {@code user.dir} (run from repository root). + * Optional override: {@code -Drulescript.basedir=/path}. 
+ */ +public final class ProjectPaths { + private ProjectPaths() {} + + public static Path baseDir() { + String override = System.getProperty("rulescript.basedir"); + if (override != null && !override.isBlank()) { + return Path.of(override); + } + return Path.of(System.getProperty("user.dir")); + } +} From b7706645ffcb7e45bf3c135b1d6853fed1d7e016 Mon Sep 17 00:00:00 2001 From: Wesley Zheng Date: Tue, 5 May 2026 18:56:14 -0700 Subject: [PATCH 4/5] Edit scripts --- .gitignore | 22 +---- LICENSE | 176 ++++++++++++++++++++++++++++++++++ README.md | 70 ++++---------- pom.xml | 6 ++ scripts/build-qed-prover.sh | 24 ++++- scripts/generate-rule-json.sh | 2 +- scripts/test-codegen.sh | 2 +- scripts/test-rules.sh | 32 +++++-- 8 files changed, 244 insertions(+), 90 deletions(-) create mode 100644 LICENSE diff --git a/.gitignore b/.gitignore index c9980ef..1ba256f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,13 +3,10 @@ !.github/** !.envrc !.gitignore +!LICENSE !flake.nix !flake.lock !*.md - -# Rust -!Cargo.toml -!Cargo.lock !pom.xml !mvnw !mvnw.cmd @@ -18,24 +15,7 @@ .mvn/wrapper/maven-wrapper.jar !scripts !scripts/** -!examples -!examples/** !src !src/** -# Java parser (RuleScript / Qed) -!parser -!parser/** - -# Parser: build & IDE (after !parser/**) -parser/target/ -parser/.idea/ -parser/*.iml -parser/.mvn/wrapper/maven-wrapper.jar -parser/.vscode -parser/.devcontainer -parser/.direnv/ -parser/.envrc - -# OS noise under tracked trees **/.DS_Store diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d9a10c0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,176 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS diff --git a/README.md b/README.md index 685510b..b99695d 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ # RuleScript -RuleScript is an engine-agnostic domain-specific language (DSL) for developing query rewrite rules. +RuleScript is an engine-agnostic domain-specific language for developing query rewrite rules. For details, please see our [paper](http://www2.eecs.berkeley.edu/Pubs/TechRpts/2024/EECS-2024-140.pdf). ## Build -The project targets **Java 25** ([OpenJDK](https://openjdk.org/) / Temurin builds). Build with Maven: +The project targets Java 25. Build with Maven: ```sh ./mvnw compile -q @@ -13,23 +13,13 @@ The project targets **Java 25** ([OpenJDK](https://openjdk.org/) / Temurin build ## Generate Rules -Rules are generated per backend by running the corresponding tester. First build a classpath: +Rules are generated per backend by running the corresponding tester: ```sh -./mvnw dependency:build-classpath -q -DincludeTypes=jar -Dmdep.outputFile=/tmp/cp.txt -``` - -Then run the tester for the target backend: - -```sh -# CockroachDB -java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.Cockroach.CockroachTester - -# Apache Calcite -java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.Calcite.CalciteTester - -# MySQL -java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.MySQL.Tests.MySQLTester +./mvnw -q compile exec:java@cockroach-codegen # CockroachDB +./mvnw -q compile exec:java@calcite-codegen-test # Apache Calcite +./mvnw -q compile exec:java@mysql-tester # MySQL +# See the Datafusion folder for details about RuleScript generation for DataFusion ``` Generated rule files are written to each backend's `Generated/` directory. @@ -38,7 +28,7 @@ Generated rule files are written to each backend's `Generated/` directory. Rules are defined in `src/main/java/org/qed/RRuleInstances/` as Java records implementing `RRule`. 
Each rule provides a `before()` pattern and an `after()` transformation in terms of RuleScript's relational algebra operators. The generators pick up every file in that directory automatically. -**Example: `FilterMerge`** +Example: `FilterMerge` ```java // src/main/java/org/qed/RRuleInstances/FilterMerge.java @@ -63,51 +53,23 @@ Running the generators will produce: - `src/main/java/org/qed/Backends/Calcite/Generated/FilterMerge.java` — the Apache Calcite rule implementation - `src/main/java/org/qed/Backends/Cockroach/Generated/FilterMerge.opt` — the CockroachDB optgen rule -To also add a Calcite test, create `src/main/java/org/qed/Backends/Calcite/Tests/FilterMergeTest.java` with a `public static void runTest()` method that constructs `before` and `after` plans using `RuleBuilder` and calls `tester.verify(runner, before, after)`. The CalciteTester discovers and runs all `*Test.java` files in that directory automatically. - For a full description of the rule language and available operators, see the [paper](http://www2.eecs.berkeley.edu/Pubs/TechRpts/2024/EECS-2024-140.pdf). -## Apache DataFusion Backend - -A separate Rust implementation targeting Apache DataFusion is available at [here](https://github.com/qed-solver/rulescript). +## Qed Proofs on Rules -## Test Cases +RuleScript turns each `RRule` into Qed JSON and runs the Rust [Qed prover](https://github.com/qed-solver/prover) against it to check Ged-level provability of the before/after pair. -### Apache Calcite +You will need to install `jq`, `z3`, and `cvc5` yourself and put them on `PATH`. Read [qed-solver/prover](https://github.com/qed-solver/prover) for how to install compatible versions. -Individual rule tests live in `src/main/java/org/qed/Backends/Calcite/Tests/`. 
All tests are run automatically when the Calcite tester is invoked: +After you add or change rules as Java records in `src/main/java/org/qed/RRuleInstances/`, run the following from the repository root: ```sh -java -cp "target/classes:$(cat /tmp/cp.txt)" org.qed.Backends.Calcite.CalciteTester +./mvnw compile +bash scripts/generate-rule-json.sh # Qed JSON under tmp-rules/ +bash scripts/build-qed-prover.sh # clone ./qed-prover and build target/release/qed-prover (skip if already built) +bash scripts/test-rules.sh # run the prover on tmp-rules/*.json ``` -### CockroachDB - -The generated `.opt` rule files live in `src/main/java/org/qed/Backends/Cockroach/Generated/`. - -To run them against CockroachDB: - -1. Clone the [CockroachDB repository](https://github.com/cockroachdb/cockroach) and check out commit `4b80cd59c6299f26b2b4f02a96064d5127ccad94` — this is the exact state of the codebase the rules were developed against. - -2. Copy the generated rule files and test data into the CockroachDB source tree: - - Rule files → `pkg/sql/opt/norm/rules/` - - Test data → `pkg/sql/opt/norm/testdata/rules/CockroachTests` - -3. Check your environment is set up correctly: - ```sh - ./dev doctor - ``` - -4. Build CockroachDB: - ```sh - ./dev build - ``` - -5. 
Run the CockroachDB tests: - ```sh - ./dev test pkg/sql/opt/norm -f=TestNormRules/CockroachTests -v - ``` - ## License Copyright 2026 The Qed Team diff --git a/pom.xml b/pom.xml index 115abe5..dfd656a 100644 --- a/pom.xml +++ b/pom.xml @@ -44,6 +44,12 @@ org.qed.Backends.Cockroach.CockroachTester + + mysql-tester + + org.qed.Backends.MySQL.Tests.MySQLTester + + qed-parser-main diff --git a/scripts/build-qed-prover.sh b/scripts/build-qed-prover.sh index 10fe4b0..17d98fa 100644 --- a/scripts/build-qed-prover.sh +++ b/scripts/build-qed-prover.sh @@ -3,8 +3,24 @@ # Build qed-prover with Rust nightly curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly -source $HOME/.cargo/env +source "${HOME}/.cargo/env" -git clone https://github.com/qed-solver/prover.git qed-prover -cd qed-prover -cargo +nightly build --release \ No newline at end of file +if [[ "$(uname -s)" == "Darwin" ]] && command -v brew >/dev/null 2>&1; then + if brew --prefix z3 >/dev/null 2>&1; then + _z3="$(brew --prefix z3)" + export CPATH="${_z3}/include${CPATH:+:${CPATH}}" + export LIBRARY_PATH="${_z3}/lib${LIBRARY_PATH:+:${LIBRARY_PATH}}" + fi + if brew --prefix llvm >/dev/null 2>&1; then + export LIBCLANG_PATH="$(brew --prefix llvm)/lib" + fi +fi + +ROOT_DIR="$(pwd)" +PROVER_DIR="${ROOT_DIR}/qed-prover" +if [[ ! 
-d "${PROVER_DIR}/.git" ]]; then + git clone https://github.com/qed-solver/prover.git "${PROVER_DIR}" +fi + +cd "${PROVER_DIR}" +cargo +nightly build --release diff --git a/scripts/generate-rule-json.sh b/scripts/generate-rule-json.sh index 172b0e5..544846e 100644 --- a/scripts/generate-rule-json.sh +++ b/scripts/generate-rule-json.sh @@ -26,7 +26,7 @@ public class JsonGenerator { EOF # Build classpath -MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +MAVEN_CP=$(./mvnw dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) CLASSPATH="target/classes:${MAVEN_CP}" # Compile the generator diff --git a/scripts/test-codegen.sh b/scripts/test-codegen.sh index 064b171..3d985bd 100644 --- a/scripts/test-codegen.sh +++ b/scripts/test-codegen.sh @@ -27,7 +27,7 @@ public class RuleGenerator { EOF # Build classpath -MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +MAVEN_CP=$(./mvnw dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) CLASSPATH="target/classes:${MAVEN_CP}" # Compile the generator diff --git a/scripts/test-rules.sh b/scripts/test-rules.sh index 5dc149b..d27b9e4 100644 --- a/scripts/test-rules.sh +++ b/scripts/test-rules.sh @@ -2,8 +2,18 @@ # Test all generated rules with qed-prover -echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY -echo "" >> $GITHUB_STEP_SUMMARY +if [ -z "${GITHUB_STEP_SUMMARY:-}" ]; then + GITHUB_STEP_SUMMARY="tmp-rules/qed-prover-step-summary.md" +fi +mkdir -p "$(dirname "$GITHUB_STEP_SUMMARY")" + +log_line() { + printf '%s\n' "$@" + printf '%s\n' "$@" >> "$GITHUB_STEP_SUMMARY" +} + +log_line "## QED Prover Test Results" +log_line "" failed_rules="" total_count=0 @@ -13,21 +23,25 @@ for json_file in tmp-rules/*.json; do rule_name=$(basename "$json_file" .json) total_count=$((total_count + 1)) ./qed-prover/target/release/qed-prover "$json_file" || true - + result_file="${json_file%.json}.result" if [ -f "$result_file" ] && jq -e '.provable == true' "$result_file" > 
/dev/null 2>&1; then - echo "✅ $rule_name: PASSED" >> $GITHUB_STEP_SUMMARY + log_line "✅ $rule_name: PASSED" passed_count=$((passed_count + 1)) else - echo "❌ $rule_name: FAILED" >> $GITHUB_STEP_SUMMARY + log_line "❌ $rule_name: FAILED" failed_rules="$failed_rules$rule_name," fi done -echo "" >> $GITHUB_STEP_SUMMARY -echo "**Summary:** $passed_count/$total_count passed" >> $GITHUB_STEP_SUMMARY +log_line "" +log_line "**Summary:** $passed_count/$total_count passed" if [ -n "$failed_rules" ]; then - echo "::error::Failed rules: ${failed_rules%,}" + msg="Failed rules: ${failed_rules%,}" + echo "$msg" >&2 + if [ -n "${GITHUB_ACTIONS:-}" ]; then + echo "::error::$msg" + fi exit 1 -fi \ No newline at end of file +fi From 8e0990ccab7aae9bcacf2b6791cc2005cac2516b Mon Sep 17 00:00:00 2001 From: Wesley Zheng Date: Tue, 5 May 2026 19:04:54 -0700 Subject: [PATCH 5/5] Adjusting readmes --- README.md | 4 +- .../java/org/qed/Backends/Calcite/README.md | 30 +++++++++++++ .../java/org/qed/Backends/Cockroach/README.md | 44 +++++++++++++++++++ .../java/org/qed/Backends/MySQL/README.md | 23 ++++++++++ .../java/org/qed/Backends/ProxySQL/README.md | 29 ++++++++++++ 5 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/qed/Backends/Calcite/README.md create mode 100644 src/main/java/org/qed/Backends/Cockroach/README.md create mode 100644 src/main/java/org/qed/Backends/MySQL/README.md create mode 100644 src/main/java/org/qed/Backends/ProxySQL/README.md diff --git a/README.md b/README.md index b99695d..7a3bfcc 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ For a full description of the rule language and available operators, see the [pa RuleScript turns each `RRule` into Qed JSON and runs the Rust [Qed prover](https://github.com/qed-solver/prover) against it to check Ged-level provability of the before/after pair. -You will need to install `jq`, `z3`, and `cvc5` yourself and put them on `PATH`. 
Read [qed-solver/prover](https://github.com/qed-solver/prover) for how to install compatible versions. +You will need to install `jq`, `z3`, and `cvc5` yourself and put them on `PATH`. Read [qed-solver/prover](https://github.com/qed-solver/prover) for more details. After you add or change rules as Java records in `src/main/java/org/qed/RRuleInstances/`, run the following from the repository root: @@ -70,6 +70,8 @@ bash scripts/build-qed-prover.sh # clone ./qed-prover and build target/rele bash scripts/test-rules.sh # run the prover on tmp-rules/*.json ``` +`scripts/test-rules.sh` prints its results to the terminal and appends a markdown summary to the file named by `GITHUB_STEP_SUMMARY`, or to `tmp-rules/qed-prover-step-summary.md` when that variable is unset. + ## License Copyright 2026 The Qed Team diff --git a/src/main/java/org/qed/Backends/Calcite/README.md b/src/main/java/org/qed/Backends/Calcite/README.md new file mode 100644 index 0000000..572f426 --- /dev/null +++ b/src/main/java/org/qed/Backends/Calcite/README.md @@ -0,0 +1,30 @@ +# Apache Calcite + +This directory contains Apache Calcite-specific artifacts generated from RuleScript rules. + +## Generate Rules and Execute Tests + +Run the Calcite generator from the repository root: + +```sh +./mvnw -q compile exec:java@calcite-codegen-test +``` + +This command does two things: + +1. Generates Calcite rule classes into: + - `src/main/java/org/qed/Backends/Calcite/Generated/*.java` +2. Runs all Calcite backend tests discovered under: + - `src/main/java/org/qed/Backends/Calcite/Tests/*Test.java` + +## Writing Calcite Tests + +Calcite tests are Java classes in `src/main/java/org/qed/Backends/Calcite/Tests/` with a public static `runTest()` method. + +Typical flow in each test: + +1. Build `before` and `after` plans with `RuleBuilder`. +2. Load the generated rule into a `HepPlanner` through `CalciteTester`. +3. Call `tester.verify(runner, before, after)`. + +`CalciteTester.runAllTests()` reflects over all `*Test.java` files and invokes `runTest()` automatically during `calcite-codegen-test`.
diff --git a/src/main/java/org/qed/Backends/Cockroach/README.md b/src/main/java/org/qed/Backends/Cockroach/README.md new file mode 100644 index 0000000..294d3ab --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/README.md @@ -0,0 +1,44 @@ +# CockroachDB + +This directory contains CockroachDB-specific artifacts generated from RuleScript rules. + +## Generate Rules + +Run the Cockroach generator from the repository root: + +```sh +./mvnw -q compile exec:java@cockroach-codegen +``` + +Generated optgen rules are written to: + +- `src/main/java/org/qed/Backends/Cockroach/Generated/*.opt` + +Cockroach test cases are maintained in: + +- `src/main/java/org/qed/Backends/Cockroach/CockroachTests` + +## Running against CockroachDB + +The workflow below is the Cockroach-specific setup that used to live in the root README. + +1. Clone [cockroachdb/cockroach](https://github.com/cockroachdb/cockroach) and check out commit: + + ```text + 4b80cd59c6299f26b2b4f02a96064d5127ccad94 + ``` + +2. Copy RuleScript outputs into the Cockroach tree: + + - Rule files from `Generated/*.opt` -> `pkg/sql/opt/norm/rules/` + - `CockroachTests` -> `pkg/sql/opt/norm/testdata/rules/CockroachTests` + +3. In the Cockroach repository: + + ```sh + ./dev doctor + ./dev build + ./dev test pkg/sql/opt/norm -f=TestNormRules/CockroachTests -v + ``` + +This validates that generated rules compile and behave as expected in Cockroach's optimizer test harness. diff --git a/src/main/java/org/qed/Backends/MySQL/README.md b/src/main/java/org/qed/Backends/MySQL/README.md new file mode 100644 index 0000000..6b06970 --- /dev/null +++ b/src/main/java/org/qed/Backends/MySQL/README.md @@ -0,0 +1,23 @@ +# MySQL + +This directory contains MySQL-specific artifacts produced from RuleScript rules. + +## Status + +This backend is currently legacy and is kept for compatibility and historical validation. 
+ +## Generate Rules + +Run the MySQL generator from the repository root: + +```sh +./mvnw -q compile exec:java@mysql-tester +``` + +Generated SQL files are written to: + +- `src/main/java/org/qed/Backends/MySQL/Generated/*.sql` + +MySQL test templates are stored in: + +- `src/main/java/org/qed/Backends/MySQL/Tests/*Test.sql` \ No newline at end of file diff --git a/src/main/java/org/qed/Backends/ProxySQL/README.md b/src/main/java/org/qed/Backends/ProxySQL/README.md new file mode 100644 index 0000000..e78fd2f --- /dev/null +++ b/src/main/java/org/qed/Backends/ProxySQL/README.md @@ -0,0 +1,29 @@ +# ProxySQL + +This directory contains ProxySQL-specific rewrite artifacts and test helpers. + +## Status + +This backend is currently legacy and is retained for compatibility and prior experiments. + +## Generated and test files + +Generated SQL rules are in: + +- `src/main/java/org/qed/Backends/ProxySQL/Generated/*.sql` + +Test SQL and helper scripts are in: + +- `src/main/java/org/qed/Backends/ProxySQL/Tests/*Test.sql` +- `src/main/java/org/qed/Backends/ProxySQL/Tests/script-proxysql.sh` + +## Running ProxySQL tests + +Use the checked-in generated SQL and run the helper script: + +```sh +cd src/main/java/org/qed/Backends/ProxySQL/Tests +bash script-proxysql.sh +``` + +The script expects local ProxySQL/MySQL endpoints and credentials defined at the top of the file. \ No newline at end of file