diff --git a/docs/dialect/aggregation-operators.rst b/docs/dialect/aggregation-operators.rst index f34d1c0..69d7f65 100644 --- a/docs/dialect/aggregation-operators.rst +++ b/docs/dialect/aggregation-operators.rst @@ -401,6 +401,15 @@ Notes - MERGE is an aggregate operation that processes all matching rows - The operation sorts data internally, so pre-sorting is not required +.. note:: + + CLUSTER and MERGE cannot be combined in a single ``SELECT`` — MERGE + aggregates rows away while CLUSTER is a per-row window over those same rows, + so no single query expresses both. Transpiling ``SELECT MERGE(interval), + CLUSTER(interval) FROM features`` raises a ``ValueError``. Use them in + separate queries instead — for example, CLUSTER over a subquery, or MERGE + over one. + Related Operators ~~~~~~~~~~~~~~~~~ diff --git a/src/giql/expander.py b/src/giql/expander.py index ae9ca3e..bab00fb 100644 --- a/src/giql/expander.py +++ b/src/giql/expander.py @@ -102,9 +102,15 @@ class ExpansionContext: ``capabilities`` and ``sqlglot_dialect``. tables : Tables The registered :class:`~giql.table.Tables` container. + registry : ExpanderRegistry | None + The registry the pass is resolving against. Carried so a whole-query + rewrite (CLUSTER / MERGE) can re-enter :func:`expand_operators` over the + SELECT it just restructured and expand sibling operators it copied into + it, honoring a custom-registry pass run. ``None`` for a standalone + context built outside the pass. """ - __slots__ = ("node", "resolution", "target", "tables", "_alias_seq") + __slots__ = ("node", "resolution", "target", "tables", "registry", "_alias_seq") def __init__( self, @@ -113,11 +119,13 @@ def __init__( target: Target, tables: Tables, alias_seq: Callable[[], str] | None = None, + registry: ExpanderRegistry | None = None, ) -> None: self.node = node self.resolution = resolution self.target = target self.tables = tables + self.registry = registry # A single sequence is threaded across every context built for one # ``ExpandOperators`` run so aliases minted for sibling operators never # collide; a standalone context falls back to its own sequence. @@ -514,7 +522,7 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for "valid resolution metadata; pass 1 (resolve_operator_refs) must " "run first and annotate every operator node." ) - ctx = ExpansionContext(node, resolution, target, tables, alias_seq) + ctx = ExpansionContext(node, resolution, target, tables, alias_seq, registry=reg) replacement = fn(node, ctx) if not isinstance(replacement, exp.Expression): raise TypeError( diff --git a/src/giql/expanders/cluster.py b/src/giql/expanders/cluster.py new file mode 100644 index 0000000..6285f82 --- /dev/null +++ b/src/giql/expanders/cluster.py @@ -0,0 +1,564 @@ +"""The CLUSTER operator expander (epic #137, issue #144). + +CLUSTER assigns a cluster id to every interval, grouping each run of mutually +adjacent intervals (optionally within a maximum gap, strand-aware, and gated by a +pairwise predicate) under one id. It cannot be a single window function because it +needs a window over a window (``SUM`` over a ``LAG``-derived flag), so the +expansion restructures the enclosing query into a two-level form:: + + SELECT *, CLUSTER(interval) AS cluster_id FROM features + +becomes:: + + SELECT *, SUM(is_new_cluster) OVER (PARTITION BY chrom ORDER BY start) AS cluster_id + FROM ( + SELECT *, + CASE WHEN LAG(end) OVER (...) >= start THEN 0 ELSE 1 END AS is_new_cluster + FROM features + ) AS lag_calc + +(The ``becomes::`` form is simplified for readability; the emitted SQL quotes +identifiers and appends ``NULLS LAST`` to the window ORDER BY.) + +This module is the AST-expansion replacement for the legacy +:class:`giql.transformer.ClusterTransformer`, which ran as a pre-pass transformer +on the raw parsed AST. It produces the same SQL; the existing CLUSTER +transpilation and bedtools oracle tests are the migration oracle. + +Unlike the node-local expanders for DISTANCE / NEAREST / DISJOIN — whose result +replaces a single node — CLUSTER is a **whole-query rewrite**. It shares only the +*shape* of :func:`giql.expanders.nearest._fallback_form` — return the operator +node unchanged so the pass's ``node.replace`` is a no-op — but the *mechanism* +differs. ``nearest`` does an in-place child ``.replace`` because its LATERAL has a +parent to replace through; CLUSTER instead copies the enclosing +:class:`~sqlglot.expressions.Select`, restructures the copy, and *transplants* its +contents back onto the original SELECT (see :func:`transplant`). That root- +preserving transplant is required because the canonical +``SELECT *, CLUSTER(...) FROM t`` puts CLUSTER at the *root* ``SELECT``, which has +no parent to replace it through; transplanting preserves the root's identity (the +pass returns the same expression object it was handed). + +The pass walks the tree and collects every operator node, expanding deepest-first, +so this expander never recurses into CTEs / subqueries: a nested CLUSTER is +collected and expanded on its own. The shared cluster-restructure helper +(:func:`expand_cluster_query`) is reused by :mod:`giql.expanders.merge`, since +MERGE is built on CLUSTER. +""" + +from __future__ import annotations + +from typing import NamedTuple +from typing import TypeVar + +from sqlglot import exp + +from giql.constants import DEFAULT_CHROM_COL +from giql.constants import DEFAULT_END_COL +from giql.constants import DEFAULT_START_COL +from giql.constants import DEFAULT_STRAND_COL +from giql.expander import ExpansionContext +from giql.expander import expand_operators +from giql.expander import register +from giql.expressions import GIQLCluster +from giql.expressions import GIQLMerge +from giql.table import Tables +from giql.targets import GenericTarget + +_T = TypeVar("_T", bound=exp.Expression) + + +class GenomicColumns(NamedTuple): + """The resolved physical column names CLUSTER / MERGE operate over. + + Derived from the enclosing FROM table by :func:`genomic_columns`. A + :class:`~typing.NamedTuple` so it still unpacks and indexes positionally + (``chrom, start, end, strand = columns``) while giving the four fields names. + """ + + chrom: str + start: str + end: str + strand: str + + +@register(GenericTarget, GIQLCluster) +def expand_cluster(node: GIQLCluster, ctx: ExpansionContext) -> exp.Expression: + """Expand a CLUSTER node by restructuring its enclosing SELECT in place. + + Registered for :class:`~giql.targets.GenericTarget`, so every target resolves + to it through the registry's generic chain (CLUSTER emits identical SQL across + targets). Locates the enclosing :class:`~sqlglot.expressions.Select`, derives + the genomic columns from the FROM table, and rewrites that SELECT in place into + the two-level ``lag_calc`` form. Returns the original CLUSTER node (now + unreachable from the rewritten root SELECT) so the pass's ``node.replace`` is a + no-op. + + Parameters + ---------- + node : GIQLCluster + The CLUSTER node being expanded. + ctx : ExpansionContext + The expansion context; only its :attr:`~ExpansionContext.tables` is read + (CLUSTER derives its columns from the enclosing FROM table, not from + resolution metadata). + + Returns + ------- + exp.Expression + The CLUSTER node, unchanged — the surrounding SELECT was mutated in place. + """ + select = node.find_ancestor(exp.Select) + if select is None: + # Defensive: CLUSTER only ever parses inside a SELECT projection. With no + # enclosing SELECT there is nothing to restructure; leave the node for the + # generator to error on, exactly as the legacy transformer's + # non-Select guard did. + return node + require_top_level_projection(select, node, GIQLCluster) + reject_cluster_merge_mix(select) + columns = genomic_columns(select, ctx.tables) + # Build the transformed query from a detached copy so the intermediate never + # aliases live nodes, then transplant its args onto the original SELECT to + # preserve its identity (and so a root SELECT is rewritten without a parent to + # replace it through). + source = select.copy() + transformed = expand_cluster_query(source, columns) + if transformed is None: + # No-op for a CLUSTER that parses outside the SELECT projection (e.g. in + # WHERE / ORDER BY): find_projected finds none in the copy, so there is + # nothing to restructure and the operator leaks to the generator exactly + # as on `main`. require_top_level_projection above already rejects the + # in-projection-expression case; this guards the out-of-projection case. + return node + transplant(select, transformed) + # copy()+transplant duplicated the enclosing WHERE / HAVING into the new + # lag_calc subquery; the originals the pass collected are now unreachable, so + # re-run the pass over the restructured SELECT to expand any sibling pass-3 + # operators (spatial predicates, DISTANCE) carried into it. Safe from + # recursion: the CLUSTER node is already replaced by its SUM window. (#144 B1) + expand_operators(select, ctx.target, ctx.tables, ctx.registry) + return node + + +def genomic_columns(select: exp.Select, tables: Tables) -> GenomicColumns: + """Return the ``(chrom, start, end, strand)`` columns for *select*'s FROM table. + + Part of the shared CLUSTER/MERGE expansion toolkit (reused by + :mod:`giql.expanders.merge`). Mirrors the legacy + ``ClusterTransformer._get_genomic_columns`` / ``_get_table_name``: read the + FROM-clause table name, look it up in *tables*, and use its configured column + names, falling back to the canonical defaults (and to the default strand + column when the table declares none). + """ + table_name: str | None = None + from_clause = select.args.get("from_") + if from_clause is not None and isinstance(from_clause.this, exp.Table): + table_name = from_clause.this.name + + chrom_col = DEFAULT_CHROM_COL + start_col = DEFAULT_START_COL + end_col = DEFAULT_END_COL + strand_col = DEFAULT_STRAND_COL + + if table_name: + table = tables.get(table_name) + if table: + chrom_col = table.chrom_col + start_col = table.start_col + end_col = table.end_col + if table.strand_col: + strand_col = table.strand_col + + return GenomicColumns(chrom_col, start_col, end_col, strand_col) + + +def extract_stranded(stranded_expr: exp.Expression | None) -> bool: + """Coerce a CLUSTER/MERGE ``stranded`` operand to a bool. Shared toolkit. + + Mirrors the legacy per-transformer coercion exactly: a missing operand is + ``False``; an ``exp.Boolean`` yields its raw ``.this``; an ``exp.Literal`` + compares case-folded to ``TRUE``. The final two arms (``exp.Literal`` and the + string-truthiness fallback) are defensive — the GIQL grammar only ever produces + ``exp.Boolean`` for ``stranded := `` — retained for parity with the + legacy port. + """ + if stranded_expr is None: + return False + if isinstance(stranded_expr, exp.Boolean): + return stranded_expr.this + if isinstance(stranded_expr, exp.Literal): + return str(stranded_expr.this).upper() == "TRUE" + return str(stranded_expr).upper() in ("TRUE", "1", "YES") + + +def expand_cluster_query( + query: exp.Select, columns: GenomicColumns +) -> exp.Select | None: + """Restructure *query* for every CLUSTER in its projection; ``None`` if none. + + Finds CLUSTER expressions in *query*'s SELECT list and rewrites the query into + the two-level ``lag_calc`` form once per CLUSTER (chaining, as the legacy + transformer did). Operates on *query* directly and returns the rewritten + query. Returns ``None`` when *query* projects no CLUSTER, so callers can treat + that as a no-op. + + Reused by :mod:`giql.expanders.merge`, whose intermediate clustered query is + restructured through this same helper. + """ + cluster_exprs = find_projected(query, GIQLCluster) + if not cluster_exprs: + return None + if len(cluster_exprs) > 1: + # Chaining the rewrite per CLUSTER yields a duplicate ``lag_calc`` alias + # and an ``is_new_cluster`` binder error — non-executable SQL. Fail loudly, + # mirroring the multiple-MERGE guard, rather than emitting it (#144 A15). + raise ValueError("Multiple CLUSTER expressions not yet supported") + for cluster_expr in cluster_exprs: + query = _transform_for_cluster(query, cluster_expr, columns) + return query + + +def find_projected(select: exp.Select, op_type: type[_T]) -> list[_T]: + """Return *select*'s projected operators of *op_type* (bare or aliased). Toolkit. + + Shared CLUSTER/MERGE primitive: both expanders and the co-occurrence guard + locate their operator the same way — a top-level SELECT projection item that + either *is* the operator or is an ``exp.Alias`` wrapping it. + """ + found: list[_T] = [] + for expression in select.expressions: + if isinstance(expression, op_type): + found.append(expression) + elif isinstance(expression, exp.Alias) and isinstance( + expression.this, op_type + ): + found.append(expression.this) + return found + + +def require_top_level_projection( + select: exp.Select, node: exp.Expression, op_type: type +) -> None: + """Raise if *node* is buried inside a projection expression. Shared toolkit. + + A CLUSTER / MERGE is only expandable as a *top-level* projection item — bare + or directly aliased — because the whole-query rewrite restructures the SELECT + around it. One nested inside a larger projection expression such as + ``ABS(CLUSTER(interval))`` has no coherent rewrite and would otherwise leak an + unexpanded operator to the generator, so fail loudly here (#144 A16). An + operator that parses *outside* the projection entirely (e.g. in WHERE / + ORDER BY) is not under any projection item and is left for the expander's + existing no-op path. + """ + operator = op_type.__name__.removeprefix("GIQL").upper() + for projection in select.expressions: + inner = projection.this if isinstance(projection, exp.Alias) else projection + if inner is node: + return + if any(descendant is node for descendant in inner.walk()): + raise ValueError( + f"{operator} must be a top-level projection item; it cannot be " + "nested inside another expression (e.g. a function call or " + "arithmetic)." + ) + + +def reject_cluster_merge_mix(select: exp.Select) -> None: + """Raise if *select* projects both a CLUSTER and a MERGE. Shared toolkit. + + The two are mutually incompatible in one SELECT: MERGE aggregates the rows + away (it rewrites the query into a ``GROUP BY`` over a clustered subquery) + while CLUSTER is a per-row window over those same rows, so no coherent single + query expresses both. The legacy pre-pass chained the two transformers and + emitted *non-executable* SQL for this shape (a window over ``GROUP BY``- + aggregated rows — a DuckDB ``BinderException``, never a leaked operator). The + new in-place ``transplant`` cannot express both at all: whichever expander + runs first rewrites the shared SELECT and strands the sibling as an unexpanded + node. Fail loudly here — mirroring the ``Multiple MERGE expressions not yet + supported`` guard — so the combination raises a clear diagnostic rather than + emitting broken SQL. + """ + if find_projected(select, GIQLCluster) and find_projected(select, GIQLMerge): + raise ValueError( + "CLUSTER and MERGE cannot be combined in a single SELECT; MERGE " + "aggregates rows while CLUSTER is a per-row window. Use them in " + "separate queries (e.g. CLUSTER over a subquery, or MERGE over one)." + ) + + +def _transform_for_cluster( + query: exp.Select, cluster_expr: GIQLCluster, columns: GenomicColumns +) -> exp.Select: + """Rewrite *query* into the two-level ``lag_calc`` form for one CLUSTER. + + A byte-for-byte port of the legacy + ``ClusterTransformer._transform_for_cluster``, with the genomic columns passed + in (the legacy method re-derived them from the FROM table) rather than read off + ``self``. Builds an inner ``lag_calc`` subquery that materializes an + ``is_new_cluster`` flag from a ``LAG`` window, then an outer query whose + ``SUM(is_new_cluster) OVER (...)`` window replaces the CLUSTER projection. + """ + chrom_col, start_col, end_col, strand_col = columns + + # Extract CLUSTER parameters + distance_expr = cluster_expr.args.get("distance") + + # Handle distance parameter - could be int literal or None + if distance_expr: + if isinstance(distance_expr, exp.Literal): + distance = int(distance_expr.this) + else: + # Defensive: the grammar only yields exp.Literal for a distance + # argument, so this non-Literal fallback is unreachable via the parser; + # retained for parity with the legacy port. + try: + distance = int(str(distance_expr.this)) + except (ValueError, AttributeError): + distance = 0 + else: + distance = 0 + + stranded = extract_stranded(cluster_expr.args.get("stranded")) + + # Build partition clause + partition_cols = [exp.column(chrom_col, quoted=True)] + if stranded: + partition_cols.append(exp.column(strand_col, quoted=True)) + + # Build ORDER BY for window + order_by = [exp.Ordered(this=exp.column(start_col, quoted=True))] + + # Create LAG window spec + lag_window = exp.Window( + this=exp.Anonymous(this="LAG", expressions=[exp.column(end_col, quoted=True)]), + partition_by=partition_cols, + order=exp.Order(expressions=order_by), + ) + + # Add distance offset if specified + if distance > 0: + lag_with_distance = exp.Add( + this=lag_window, expression=exp.Literal.number(distance) + ) + else: + lag_with_distance = lag_window + + # Build the adjacency condition (predecessor end >= current start). + adjacency = exp.GTE( + this=lag_with_distance, + expression=exp.column(start_col, quoted=True), + ) + + # An optional predicate further restricts which adjacent intervals + # are kept together: a row stays in the current cluster only when it + # is adjacent to its predecessor AND the predicate holds between them. + # ``PREV(col)`` references in the predicate resolve to the predecessor + # row via LAG over the same partition/order as the adjacency window. + predicate_expr = cluster_expr.args.get("predicate") + if predicate_expr is not None: + rewritten_predicate = _rewrite_predecessor_refs( + predicate_expr, partition_cols, order_by + ) + keep_together = exp.And( + this=adjacency, + expression=exp.Paren(this=rewritten_predicate), + ) + else: + keep_together = adjacency + + # Create CASE expression for is_new_cluster + case_expr = exp.Case( + ifs=[ + exp.If( + this=keep_together, + true=exp.Literal.number(0), + ) + ], + default=exp.Literal.number(1), + ) + + # Build CTE SELECT expressions (all original except CLUSTER, plus is_new_cluster) + cte_expressions = [] + for expression in query.expressions: + # Skip CLUSTER expressions + if isinstance(expression, GIQLCluster): + continue + elif isinstance(expression, exp.Alias) and isinstance( + expression.this, GIQLCluster + ): + continue + else: + cte_expressions.append(expression) + + # Ensure required columns for window functions are included + required_cols = {chrom_col, start_col, end_col} + if stranded: + required_cols.add(strand_col) + + # The predicate is evaluated inside the lag_calc CTE, so every column + # it references (current-row columns and PREV() arguments alike) must + # be projected into that CTE. Folding them into required_cols makes the + # scope dependency explicit and keeps the columns available even when a + # later operator wraps this query in a further subquery. + if predicate_expr is not None: + required_cols |= {col.name for col in predicate_expr.find_all(exp.Column)} + + # Check if required columns are already in the select list + selected_cols = set() + for expr in cte_expressions: + if isinstance(expr, exp.Column): + selected_cols.add(expr.name) + elif isinstance(expr, exp.Alias): + # Don't count aliases as the source column + pass + elif isinstance(expr, exp.Star): + # SELECT * includes all columns + selected_cols = required_cols # Assume all are covered + break + + # Add missing required columns + # Sort the residual so the injected-column order is deterministic across runs + # (set-difference iteration order is PYTHONHASHSEED-dependent; the legacy port + # left it nondeterministic — see review #144 A2). + for col in sorted(required_cols - selected_cols): + cte_expressions.append(exp.column(col, quoted=True)) + + # Add is_new_cluster calculation + # NOTE: synthesized name; the missing __giql_ reserved prefix is left to #161. + cte_expressions.append(exp.alias_(case_expr, "is_new_cluster", quoted=False)) + + # Build CTE query + cte_select = exp.Select() + cte_select.select(*cte_expressions, copy=False) + + # Copy FROM, WHERE, GROUP BY, HAVING from original (but not ORDER BY) + # Use copy() to avoid sharing references between queries + if query.args.get("from_"): + from_clause = query.args["from_"].copy() + cte_select.set("from_", from_clause) + if query.args.get("where"): + cte_select.set("where", query.args["where"].copy()) + if query.args.get("group"): + cte_select.set("group", query.args["group"].copy()) + if query.args.get("having"): + cte_select.set("having", query.args["having"].copy()) + + # Create outer query with SUM over is_new_cluster + sum_window = exp.Window( + this=exp.Sum(this=exp.column("is_new_cluster")), + partition_by=partition_cols, + order=exp.Order(expressions=order_by), + ) + + # Build outer SELECT expressions (replace CLUSTER with SUM) + new_expressions = [] + for expression in query.expressions: + if isinstance(expression, GIQLCluster): + new_expressions.append(sum_window) + elif isinstance(expression, exp.Alias) and isinstance( + expression.this, GIQLCluster + ): + # Keep the alias but replace the expression + new_expressions.append( + exp.alias_(sum_window, expression.alias, quoted=False) + ) + else: + new_expressions.append(expression) + + # Build new query + new_query = exp.Select() + new_query.select(*new_expressions, copy=False) + + # Wrap CTE in subquery and set as FROM clause + # NOTE: synthesized name; the missing __giql_ reserved prefix is left to #161. + subquery = exp.Subquery( + this=cte_select, + alias=exp.TableAlias(this=exp.Identifier(this="lag_calc")), + ) + new_query.from_(subquery, copy=False) + + # Copy ORDER BY from original to outer query + if query.args.get("order"): + new_query.order_by(*query.args["order"].expressions, copy=False) + + return new_query + + +def _rewrite_predecessor_refs( + predicate: exp.Expression, + partition_cols: list[exp.Expression], + order_by: list[exp.Ordered], +) -> exp.Expression: + """Rewrite ``PREV(col)`` calls in a predicate to LAG windows. + + A byte-for-byte port of the legacy + ``ClusterTransformer._rewrite_predecessor_refs``. Bare column references in the + predicate denote the current interval. Each ``PREV(col)`` call denotes the + sorted predecessor's value of that column and is rewritten to ``LAG(col) OVER + (...)`` using the same partition/order as the cluster's adjacency window, so + the predicate is evaluated pairwise against the immediately preceding row. + Every column identifier (current-row columns and LAG arguments alike) is quoted + so that reserved-word genomic columns such as ``start`` / ``end`` are emitted as + valid SQL. + + :param predicate: + Boolean predicate expression to rewrite (not mutated). + :param partition_cols: + Window partition columns (chromosome, optionally strand). + :param order_by: + Window ORDER BY terms (start position). + :return: + A copy of the predicate with every ``PREV(...)`` call replaced by an + equivalent LAG window and all column identifiers quoted. + :raises ValueError: + If a ``PREV()`` call does not take exactly one argument, or if a + ``PREV()`` call is nested inside another (predicates compare only the + immediate predecessor). + """ + + def _is_prev(node: exp.Expression) -> bool: + return isinstance(node, exp.Anonymous) and node.name.upper() == "PREV" + + def _replace(node: exp.Expression) -> exp.Expression: + if _is_prev(node): + args = node.expressions + if len(args) != 1: + raise ValueError( + f"PREV() takes exactly one column argument; got {len(args)}." + ) + if any(_is_prev(inner) for inner in args[0].find_all(exp.Anonymous)): + raise ValueError( + "PREV() cannot be nested; a CLUSTER/MERGE predicate " + "compares only the immediate predecessor." + ) + return exp.Window( + this=exp.Anonymous(this="LAG", expressions=[args[0].copy()]), + partition_by=[col.copy() for col in partition_cols], + order=exp.Order(expressions=[term.copy() for term in order_by]), + ) + return node + + rewritten = predicate.copy().transform(_replace) + for column in rewritten.find_all(exp.Column): + column.this.set("quoted", True) + return rewritten + + +def transplant(select: exp.Select, new: exp.Select) -> None: + """Replace *select*'s contents with *new*'s, preserving *select*'s identity. + + Part of the shared CLUSTER/MERGE expansion toolkit. Clears every argument of + *select* and re-installs *new*'s, so *select* keeps its position in the + surrounding tree (and its identity as the object the pass returns) while taking + on the rewritten structure. This is how a whole-query rewrite is applied to a + *root* SELECT, which has no parent to ``replace`` through. + + Precondition: *new* MUST be a detached throwaway — a freshly-built + ``exp.Select``, as the ``_transform_for_*`` helpers return. Its children are + re-parented onto *select*, so passing a node still attached elsewhere would + corrupt that other tree. + """ + assert new.parent is None, "transplant() requires a detached `new` subtree" + select.args.clear() + for key, value in list(new.args.items()): + select.set(key, value) diff --git a/src/giql/expanders/merge.py b/src/giql/expanders/merge.py new file mode 100644 index 0000000..4191635 --- /dev/null +++ b/src/giql/expanders/merge.py @@ -0,0 +1,245 @@ +"""The MERGE operator expander (epic #137, issue #144). + +MERGE combines overlapping (and, with parameters, adjacent / strand-matched / +predicate-gated) intervals into single intervals. It is built on CLUSTER: assign +a cluster id, then aggregate ``MIN(start)`` / ``MAX(end)`` per cluster:: + + SELECT MERGE(interval) FROM features + +becomes:: + + SELECT chrom, MIN(start) AS start, MAX(end) AS end + FROM (SELECT *, CLUSTER(interval) AS __giql_cluster_id FROM features) AS clustered + GROUP BY chrom, __giql_cluster_id + ORDER BY chrom, start + +(The ``becomes::`` form is simplified for readability; the emitted SQL quotes +identifiers, appends ``NULLS LAST``, and the inner ``CLUSTER(...)`` is itself +expanded into the two-level ``lag_calc`` form.) + +This module is the AST-expansion replacement for the legacy +:class:`giql.transformer.MergeTransformer`; it produces the same SQL (the existing +MERGE transpilation and bedtools oracle tests are the migration oracle). + +Like CLUSTER (:mod:`giql.expanders.cluster`), MERGE is a **whole-query rewrite**: +it navigates to the enclosing :class:`~sqlglot.expressions.Select`, mutates it in +place, and returns the operator node unchanged so the pass's ``node.replace`` is a +no-op. The intermediate clustered subquery is restructured through CLUSTER's shared +:func:`giql.expanders.cluster.expand_cluster_query`, mirroring how the legacy +``MergeTransformer`` composed ``ClusterTransformer``. +""" + +from __future__ import annotations + +from sqlglot import exp + +from giql.expander import ExpansionContext +from giql.expander import expand_operators +from giql.expander import register + +# Shared CLUSTER/MERGE expansion toolkit (MERGE is built on CLUSTER). +from giql.expanders.cluster import GenomicColumns +from giql.expanders.cluster import expand_cluster_query +from giql.expanders.cluster import extract_stranded +from giql.expanders.cluster import find_projected +from giql.expanders.cluster import genomic_columns +from giql.expanders.cluster import reject_cluster_merge_mix +from giql.expanders.cluster import require_top_level_projection +from giql.expanders.cluster import transplant +from giql.expressions import GIQLCluster +from giql.expressions import GIQLMerge +from giql.targets import GenericTarget + + +@register(GenericTarget, GIQLMerge) +def expand_merge(node: GIQLMerge, ctx: ExpansionContext) -> exp.Expression: + """Expand a MERGE node by restructuring its enclosing SELECT in place. + + Registered for :class:`~giql.targets.GenericTarget` (MERGE emits identical SQL + across targets). Locates the enclosing + :class:`~sqlglot.expressions.Select`, derives the genomic columns from the FROM + table, and rewrites that SELECT in place into the clustered-aggregation form. + Returns the original MERGE node (now unreachable from the rewritten root SELECT) + so the pass's ``node.replace`` is a no-op. + + Parameters + ---------- + node : GIQLMerge + The MERGE node being expanded. + ctx : ExpansionContext + The expansion context; only its :attr:`~ExpansionContext.tables` is read. + + Returns + ------- + exp.Expression + The MERGE node, unchanged — the surrounding SELECT was mutated in place. + """ + select = node.find_ancestor(exp.Select) + if select is None: + return node + require_top_level_projection(select, node, GIQLMerge) + reject_cluster_merge_mix(select) + columns = genomic_columns(select, ctx.tables) + # Build from a detached copy, then transplant onto the original SELECT to + # preserve its identity (and rewrite a root SELECT without a parent to replace + # through), exactly as the CLUSTER expander does. + source = select.copy() + transformed = _transform_select_merge(source, columns) + if transformed is None: + # No-op for a MERGE that parses outside the SELECT projection (e.g. in + # WHERE / ORDER BY): find_projected finds none in the copy, so there is + # nothing to restructure and the operator leaks to the generator exactly + # as on `main`. require_top_level_projection above already rejects the + # in-projection-expression case; this guards the out-of-projection case. + return node + transplant(select, transformed) + # copy()+transplant duplicated the enclosing WHERE into the new clustered + # subquery; the originals the pass collected are now unreachable, so re-run the + # pass over the restructured SELECT to expand any sibling pass-3 operators + # carried into it. Safe from recursion: the MERGE is already gone. (#144 B1) + expand_operators(select, ctx.target, ctx.tables, ctx.registry) + return node + + +def _transform_select_merge( + query: exp.Select, columns: GenomicColumns +) -> exp.Select | None: + """Rewrite *query* for the MERGE in its projection; ``None`` if none. + + Mirrors the legacy ``MergeTransformer.transform`` dispatch: find the MERGE + expressions, reject more than one, and delegate the single supported case to + :func:`_transform_for_merge`. + """ + merge_exprs = find_projected(query, GIQLMerge) + if not merge_exprs: + return None + # For now, support only one MERGE expression + if len(merge_exprs) > 1: + raise ValueError("Multiple MERGE expressions not yet supported") + return _transform_for_merge(query, merge_exprs[0], columns) + + +def _transform_for_merge( + query: exp.Select, merge_expr: GIQLMerge, columns: GenomicColumns +) -> exp.Select: + """Rewrite *query* into the clustered-aggregation form for one MERGE. + + A byte-for-byte port of the legacy ``MergeTransformer._transform_for_merge``, + with the genomic columns passed in and the intermediate clustered query + restructured through CLUSTER's shared + :func:`giql.expanders.cluster.expand_cluster_query` (the legacy method called + ``ClusterTransformer.transform``). Builds an inner ``clustered`` subquery that + appends ``__giql_cluster_id``, then an outer query that aggregates + ``MIN(start)`` / ``MAX(end)`` per cluster. + """ + chrom_col, start_col, end_col, strand_col = columns + + # Extract MERGE parameters (same as CLUSTER) + distance_expr = merge_expr.args.get("distance") + stranded_expr = merge_expr.args.get("stranded") + predicate_expr = merge_expr.args.get("predicate") + + # Build CLUSTER expression with same parameters + cluster_kwargs = {"this": merge_expr.this} + if distance_expr: + cluster_kwargs["distance"] = distance_expr + if stranded_expr: + cluster_kwargs["stranded"] = stranded_expr + if predicate_expr is not None: + cluster_kwargs["predicate"] = predicate_expr + + cluster_expr = GIQLCluster(**cluster_kwargs) + + # Create intermediate query with CLUSTER + # Start with original query's FROM/WHERE/etc + cluster_query = exp.Select() + cluster_query.select(exp.Star(), copy=False) + # NOTE: __giql_cluster_id carries the reserved prefix; the un-prefixed + # `clustered` / `lag_calc` / `is_new_cluster` siblings are left to #161. + cluster_query.select( + exp.alias_(cluster_expr, "__giql_cluster_id", quoted=False), + append=True, + copy=False, + ) + + # Copy FROM, WHERE from original + # Use copy() to avoid sharing references between queries + if query.args.get("from_"): + cluster_query.set("from_", query.args["from_"].copy()) + if query.args.get("where"): + cluster_query.set("where", query.args["where"].copy()) + + # Apply CLUSTER transformation to get the CTE-based query + clustered = expand_cluster_query(cluster_query, columns) + assert clustered is not None, "intermediate MERGE cluster query has no CLUSTER" + cluster_query = clustered + + # Build GROUP BY columns + # Quote the chrom column (as every other chrom reference is) so a reserved-word + # custom chrom column emits valid SQL (#144 A13). + group_by_cols = [exp.column(chrom_col, quoted=True)] + + # Handle stranded parameter + stranded = extract_stranded(stranded_expr) + + if stranded: + group_by_cols.append(exp.column(strand_col, quoted=True)) + + group_by_cols.append(exp.column("__giql_cluster_id")) + + # Build SELECT expressions for merged intervals + select_exprs = [] + + # Add group-by columns (non-aggregated) + select_exprs.append(exp.column(chrom_col, quoted=True)) + if stranded: + select_exprs.append(exp.column(strand_col, quoted=True)) + + # Add merged interval bounds + select_exprs.append( + exp.alias_( + exp.Min(this=exp.column(start_col, quoted=True)), start_col, quoted=False + ) + ) + select_exprs.append( + exp.alias_( + exp.Max(this=exp.column(end_col, quoted=True)), end_col, quoted=False + ) + ) + + # Process other columns from original SELECT + for expression in query.expressions: + # Skip the MERGE expression itself + if isinstance(expression, GIQLMerge): + continue + elif isinstance(expression, exp.Alias) and isinstance( + expression.this, GIQLMerge + ): + continue + # Include other columns (they should be aggregates or in GROUP BY) + else: + select_exprs.append(expression) + + # Build final query + final_query = exp.Select() + final_query.select(*select_exprs, copy=False) + + # FROM the clustered subquery + subquery = exp.Subquery( + this=cluster_query, + alias=exp.TableAlias(this=exp.Identifier(this="clustered")), + ) + final_query.from_(subquery, copy=False) + + # Add GROUP BY + final_query.group_by(*group_by_cols, copy=False) + + # Add ORDER BY (chromosome, start) + final_query.order_by( + exp.Ordered(this=exp.column(chrom_col, quoted=True)), copy=False + ) + final_query.order_by( + exp.Ordered(this=exp.column(start_col, quoted=True)), append=True, copy=False + ) + + return final_query diff --git a/src/giql/expressions.py b/src/giql/expressions.py index 6a0f440..ffbf017 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -110,15 +110,6 @@ class SpatialPredicate(exp.Binary): #: half-open) operands are left untouched and the emitted SQL stays byte-identical. _CANONICALIZE = True -#: Default opt-out from the ``ExpandOperators`` pass (epic #137, step 2). The -#: per-operator ``GIQL_EXPAND`` flag mirrors ``GIQL_CANONICALIZE``: an operator -#: takes the new AST-expansion path only when it sets ``GIQL_EXPAND = True`` *and* -#: an expander is registered for it; otherwise the legacy ``*_sql`` emitter runs. -#: This is the opt-out default: an operator inherits it and stays on the legacy -#: emitter until its migration step flips the flag to ``True`` alongside its -#: registered expander. Operators already migrated override it on their own class. -_EXPAND = False - class Intersects(SpatialPredicate): """INTERSECTS spatial predicate. @@ -244,11 +235,14 @@ class GIQLCluster(exp.Func): "predicate": False, # pairwise boolean gate (current row vs PREV(col)) } - # Inert today: the CLUSTER/MERGE transformers rewrite these nodes before the - # ExpandOperators pass runs, so the pass never sees a GIQLCluster to dispatch - # and this flag is not a live opt-in. It is forward-looking for #144, which - # migrates these operators onto the expander registry. - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators pass (epic #137, issue #144). CLUSTER is + #: expanded by ``giql.expanders.cluster`` — a whole-query rewrite into the + #: two-level ``lag_calc`` form — replacing the legacy + #: ``giql.transformer.ClusterTransformer`` pre-pass transformer. Note CLUSTER + #: deliberately does NOT set ``GIQL_CANONICALIZE``: the expander derives its + #: columns from the FROM table, so pass 2 is intentionally a no-op here and the + #: emitted SQL stays byte-identical to the legacy pre-pass output. + GIQL_EXPAND = True @classmethod def from_arg_list(cls, args): @@ -291,11 +285,14 @@ class GIQLMerge(exp.Func): "predicate": False, # pairwise boolean gate (current row vs PREV(col)) } - # Inert today: the CLUSTER/MERGE transformers rewrite these nodes before the - # ExpandOperators pass runs, so the pass never sees a GIQLMerge to dispatch - # and this flag is not a live opt-in. It is forward-looking for #144, which - # migrates these operators onto the expander registry. - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators pass (epic #137, issue #144). MERGE is + #: expanded by ``giql.expanders.merge`` — a whole-query rewrite into the + #: clustered-aggregation form (built on CLUSTER) — replacing the legacy + #: ``giql.transformer.MergeTransformer`` pre-pass transformer. Like CLUSTER, + #: MERGE deliberately does NOT set ``GIQL_CANONICALIZE`` (columns come from the + #: FROM table), so pass 2 is intentionally a no-op and the SQL stays + #: byte-identical to the legacy pre-pass output. + GIQL_EXPAND = True @classmethod def from_arg_list(cls, args): diff --git a/src/giql/resolver.py b/src/giql/resolver.py index 7869d28..8431255 100644 --- a/src/giql/resolver.py +++ b/src/giql/resolver.py @@ -22,18 +22,20 @@ asserts every operator slot carries well-formed resolution metadata, mirroring ``sqlglot``'s ``validate_qualify_columns`` and Spark's ``CheckAnalysis``. -Scope note (epic #114, steps 1-3) ---------------------------------- -The pass is behavior-preserving. DISJOIN's expander -(``giql.expanders.disjoin``, step 2) and NEAREST's emitter -(``BaseGIQLGenerator.giqlnearest_sql``, step 3) consume the attached metadata; -DISTANCE and the spatial predicates still use the generator's legacy resolver -paths and ignore everything attached here until their port issues land. The -resolution semantics computed here mirror the generator's historical -``_resolve_target_table`` / ``_resolve_disjoin_reference`` / -``_enclosing_cte_names`` (DISJOIN) and ``_resolve_nearest_reference`` / -``_find_outer_table_in_lateral_join`` (NEAREST) behavior exactly; all of those -helpers now live only here. +Scope note (epic #114 / #137) +----------------------------- +The pass is behavior-preserving. Every operator in the ``_OPERATORS`` roster now +takes the ``ExpandOperators`` path (epic #137): the DISJOIN, NEAREST, DISTANCE, +and spatial-predicate (INTERSECTS / CONTAINS / WITHIN / ``SpatialSetPredicate``) +expanders in :mod:`giql.expanders` consume the metadata attached here, resolving +their reference slots and column operands through it. CLUSTER and MERGE (#144) +declare no reference slots, so they resolve to an empty-but-valid +:class:`OperatorResolution` and their expanders derive columns from the enclosing +FROM table rather than from resolution. The resolution semantics computed here +mirror the generator's historical ``_resolve_target_table`` / +``_resolve_disjoin_reference`` / ``_enclosing_cte_names`` (DISJOIN) and +``_resolve_nearest_reference`` / ``_find_outer_table_in_lateral_join`` (NEAREST) +behavior exactly; all of those helpers now live only here. Two consequences of the zero-behavior-change constraint shape the implementation: @@ -75,8 +77,10 @@ from giql.constants import DEFAULT_STRAND_COL from giql.constants import DJ_PREFIX from giql.expressions import Contains +from giql.expressions import GIQLCluster from giql.expressions import GIQLDisjoin from giql.expressions import GIQLDistance +from giql.expressions import GIQLMerge from giql.expressions import GIQLNearest from giql.expressions import Intersects from giql.expressions import SlotSpec @@ -128,11 +132,17 @@ DEFAULT_END_COL, ) -#: The GIQL operator expression classes the pass inspects. +#: The GIQL operator expression classes the pass inspects. CLUSTER and MERGE +#: (#144) declare no reference slots, so they resolve to an empty-but-valid +#: ``OperatorResolution``; the ``ExpandOperators`` pass requires that metadata to +#: be present, and their expanders derive columns from the enclosing FROM table +#: rather than from resolution. _OPERATORS: tuple[type[exp.Expression], ...] = ( GIQLDisjoin, GIQLNearest, GIQLDistance, + GIQLCluster, + GIQLMerge, Intersects, Contains, Within, diff --git a/src/giql/transformer.py b/src/giql/transformer.py index c757aaf..e75d254 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -1,8 +1,9 @@ """Query transformers for GIQL operations. -This module contains transformers that rewrite queries containing GIQL-specific -operations (like CLUSTER, MERGE, and binned INTERSECTS joins) into equivalent -SQL with CTEs. +This module contains the pre-pass transformers that rewrite column-to-column +INTERSECTS joins (the binned equi-join and DuckDB IEJoin plans) into equivalent +SQL with CTEs. CLUSTER and MERGE were relocated to the operator-expander registry +(``giql.expanders.cluster`` / ``giql.expanders.merge``) in epic #137 (#144). """ from dataclasses import dataclass @@ -18,9 +19,6 @@ from giql.constants import DEFAULT_CHROM_COL from giql.constants import DEFAULT_END_COL from giql.constants import DEFAULT_START_COL -from giql.constants import DEFAULT_STRAND_COL -from giql.expressions import GIQLCluster -from giql.expressions import GIQLMerge from giql.expressions import Intersects from giql.table import Table from giql.table import Tables @@ -228,657 +226,6 @@ def from_intersects( ) -class ClusterTransformer: - """Transforms queries containing CLUSTER into CTE-based queries. - - CLUSTER cannot be a simple window function because it requires nested - window functions (LAG inside SUM). Instead, we transform: - - SELECT *, CLUSTER(interval) AS cluster_id FROM features - - Into: - - WITH lag_calc AS ( - SELECT *, LAG(end_pos) OVER (...) AS prev_end FROM features - ) - SELECT *, SUM(CASE WHEN prev_end >= start_pos ...) AS cluster_id - FROM lag_calc - """ - - def __init__(self, tables: Tables): - """Initialize transformer. - - :param tables: - Table configurations for column mapping - """ - self.tables = tables - - def _get_table_name(self, query: exp.Select) -> str | None: - """Extract table name from query's FROM clause. - - :param query: - Query to extract table name from - :return: - Table name if FROM contains a simple table, None otherwise - """ - from_clause = query.args.get("from_") - if not from_clause: - return None - - if isinstance(from_clause.this, exp.Table): - return from_clause.this.name - - return None - - def _get_genomic_columns(self, query: exp.Select) -> tuple[str, str, str, str]: - """Get genomic column names from table config or defaults. - - :param query: - Query to extract table and column info from - :return: - Tuple of (chrom_col, start_col, end_col, strand_col) - """ - table_name = self._get_table_name(query) - - # Default column names - chrom_col = DEFAULT_CHROM_COL - start_col = DEFAULT_START_COL - end_col = DEFAULT_END_COL - strand_col = DEFAULT_STRAND_COL - - if table_name: - table = self.tables.get(table_name) - if table: - chrom_col = table.chrom_col - start_col = table.start_col - end_col = table.end_col - if table.strand_col: - strand_col = table.strand_col - - return chrom_col, start_col, end_col, strand_col - - def transform(self, query: exp.Expression) -> exp.Expression: - """Transform query if it contains CLUSTER expressions. - - :param query: - Parsed query AST - :return: - Transformed query AST - """ - if not isinstance(query, exp.Select): - return query - - # First, recursively transform any CTEs that might contain CLUSTER - if query.args.get("with_"): - cte = query.args["with_"] - for cte_expr in cte.expressions: - if isinstance(cte_expr, exp.CTE): - # Transform the CTE's subquery - cte_expr.set("this", self.transform(cte_expr.this)) - - # Recursively transform subqueries in FROM clause - if query.args.get("from_"): - from_clause = query.args["from_"] - self._transform_subqueries_in_node(from_clause) - - # Recursively transform subqueries in JOIN clauses - if query.args.get("joins"): - for join in query.args["joins"]: - self._transform_subqueries_in_node(join) - - # Recursively transform subqueries in WHERE clause - if query.args.get("where"): - self._transform_subqueries_in_node(query.args["where"]) - - # Find all CLUSTER expressions in the SELECT clause - cluster_exprs = self._find_cluster_expressions(query) - - if not cluster_exprs: - return query - - # Transform query for each CLUSTER expression - for cluster_expr in cluster_exprs: - query = self._transform_for_cluster(query, cluster_expr) - - return query - - def _transform_subqueries_in_node(self, node: exp.Expression): - """Recursively transform subqueries within an expression node. - - :param node: - Expression node to search for subqueries - """ - # Find and transform any Subquery nodes - for subquery in node.find_all(exp.Subquery): - if isinstance(subquery.this, exp.Select): - transformed = self.transform(subquery.this) - subquery.set("this", transformed) - - def _find_cluster_expressions(self, query: exp.Select) -> list[GIQLCluster]: - """Find all CLUSTER expressions in query. - - :param query: - Query to search - :return: - List of CLUSTER expressions - """ - cluster_exprs = [] - - for expression in query.expressions: - # Check if this is a CLUSTER expression or an alias containing one - if isinstance(expression, GIQLCluster): - cluster_exprs.append(expression) - elif isinstance(expression, exp.Alias): - if isinstance(expression.this, GIQLCluster): - cluster_exprs.append(expression.this) - - return cluster_exprs - - def _transform_for_cluster( - self, query: exp.Select, cluster_expr: GIQLCluster - ) -> exp.Select: - """Transform query to compute CLUSTER using CTEs. - - :param query: - Original query - :param cluster_expr: - CLUSTER expression to transform - :return: - Transformed query with CTEs - """ - # Extract CLUSTER parameters - distance_expr = cluster_expr.args.get("distance") - - # Handle distance parameter - could be int literal or None - if distance_expr: - if isinstance(distance_expr, exp.Literal): - distance = int(distance_expr.this) - else: - # Try to extract as string and convert - try: - distance = int(str(distance_expr.this)) - except (ValueError, AttributeError): - distance = 0 - else: - distance = 0 - - stranded_expr = cluster_expr.args.get("stranded") - if stranded_expr: - # Handle different types of boolean expressions - if isinstance(stranded_expr, exp.Boolean): - stranded = stranded_expr.this - elif isinstance(stranded_expr, exp.Literal): - stranded = str(stranded_expr.this).upper() == "TRUE" - else: - # Try to extract the value as a string - stranded = str(stranded_expr).upper() in ("TRUE", "1", "YES") - else: - stranded = False - - # Get column names from table config or use defaults - chrom_col, start_col, end_col, strand_col = self._get_genomic_columns(query) - - # Build partition clause - partition_cols = [exp.column(chrom_col, quoted=True)] - if stranded: - partition_cols.append(exp.column(strand_col, quoted=True)) - - # Build ORDER BY for window - order_by = [exp.Ordered(this=exp.column(start_col, quoted=True))] - - # Create LAG window spec - lag_window = exp.Window( - this=exp.Anonymous( - this="LAG", expressions=[exp.column(end_col, quoted=True)] - ), - partition_by=partition_cols, - order=exp.Order(expressions=order_by), - ) - - # Add distance offset if specified - if distance > 0: - lag_with_distance = exp.Add( - this=lag_window, expression=exp.Literal.number(distance) - ) - else: - lag_with_distance = lag_window - - # Build the adjacency condition (predecessor end >= current start). - adjacency = exp.GTE( - this=lag_with_distance, - expression=exp.column(start_col, quoted=True), - ) - - # An optional predicate further restricts which adjacent intervals - # are kept together: a row stays in the current cluster only when it - # is adjacent to its predecessor AND the predicate holds between them. - # ``PREV(col)`` references in the predicate resolve to the predecessor - # row via LAG over the same partition/order as the adjacency window. - predicate_expr = cluster_expr.args.get("predicate") - if predicate_expr is not None: - rewritten_predicate = self._rewrite_predecessor_refs( - predicate_expr, partition_cols, order_by - ) - keep_together = exp.And( - this=adjacency, - expression=exp.Paren(this=rewritten_predicate), - ) - else: - keep_together = adjacency - - # Create CASE expression for is_new_cluster - case_expr = exp.Case( - ifs=[ - exp.If( - this=keep_together, - true=exp.Literal.number(0), - ) - ], - default=exp.Literal.number(1), - ) - - # Build CTE SELECT expressions (all original except CLUSTER, plus is_new_cluster) - cte_expressions = [] - for expression in query.expressions: - # Skip CLUSTER expressions - if isinstance(expression, GIQLCluster): - continue - elif isinstance(expression, exp.Alias) and isinstance( - expression.this, GIQLCluster - ): - continue - else: - cte_expressions.append(expression) - - # Ensure required columns for window functions are included - required_cols = {chrom_col, start_col, end_col} - if stranded: - required_cols.add(strand_col) - - # The predicate is evaluated inside the lag_calc CTE, so every column - # it references (current-row columns and PREV() arguments alike) must - # be projected into that CTE. Folding them into required_cols makes the - # scope dependency explicit and keeps the columns available even when a - # later operator wraps this query in a further subquery. - if predicate_expr is not None: - required_cols |= {col.name for col in predicate_expr.find_all(exp.Column)} - - # Check if required columns are already in the select list - selected_cols = set() - for expr in cte_expressions: - if isinstance(expr, exp.Column): - selected_cols.add(expr.name) - elif isinstance(expr, exp.Alias): - # Don't count aliases as the source column - pass - elif isinstance(expr, exp.Star): - # SELECT * includes all columns - selected_cols = required_cols # Assume all are covered - break - - # Add missing required columns - for col in required_cols - selected_cols: - cte_expressions.append(exp.column(col, quoted=True)) - - # Add is_new_cluster calculation - cte_expressions.append(exp.alias_(case_expr, "is_new_cluster", quoted=False)) - - # Build CTE query - cte_select = exp.Select() - cte_select.select(*cte_expressions, copy=False) - - # Copy FROM, WHERE, GROUP BY, HAVING from original (but not ORDER BY) - # Use copy() to avoid sharing references between queries - if query.args.get("from_"): - from_clause = query.args["from_"].copy() - cte_select.set("from_", from_clause) - if query.args.get("where"): - cte_select.set("where", query.args["where"].copy()) - if query.args.get("group"): - cte_select.set("group", query.args["group"].copy()) - if query.args.get("having"): - cte_select.set("having", query.args["having"].copy()) - - # Create outer query with SUM over is_new_cluster - sum_window = exp.Window( - this=exp.Sum(this=exp.column("is_new_cluster")), - partition_by=partition_cols, - order=exp.Order(expressions=order_by), - ) - - # Build outer SELECT expressions (replace CLUSTER with SUM) - new_expressions = [] - for expression in query.expressions: - if isinstance(expression, GIQLCluster): - new_expressions.append(sum_window) - elif isinstance(expression, exp.Alias) and isinstance( - expression.this, GIQLCluster - ): - # Keep the alias but replace the expression - new_expressions.append( - exp.alias_(sum_window, expression.alias, quoted=False) - ) - else: - new_expressions.append(expression) - - # Build new query - new_query = exp.Select() - new_query.select(*new_expressions, copy=False) - - # Wrap CTE in subquery and set as FROM clause - subquery = exp.Subquery( - this=cte_select, - alias=exp.TableAlias(this=exp.Identifier(this="lag_calc")), - ) - new_query.from_(subquery, copy=False) - - # Copy ORDER BY from original to outer query - if query.args.get("order"): - new_query.order_by(*query.args["order"].expressions, copy=False) - - return new_query - - def _rewrite_predecessor_refs( - self, - predicate: exp.Expression, - partition_cols: list[exp.Expression], - order_by: list[exp.Ordered], - ) -> exp.Expression: - """Rewrite ``PREV(col)`` calls in a predicate to LAG windows. - - Bare column references in the predicate denote the current interval. - Each ``PREV(col)`` call denotes the sorted predecessor's value of that - column and is rewritten to ``LAG(col) OVER (...)`` using the same - partition/order as the cluster's adjacency window, so the predicate is - evaluated pairwise against the immediately preceding row. Every column - identifier (current-row columns and LAG arguments alike) is quoted so - that reserved-word genomic columns such as ``start`` / ``end`` are - emitted as valid SQL, matching how the rest of this transformer quotes - genomic columns. - - :param predicate: - Boolean predicate expression to rewrite (not mutated). - :param partition_cols: - Window partition columns (chromosome, optionally strand). - :param order_by: - Window ORDER BY terms (start position). - :return: - A copy of the predicate with every ``PREV(...)`` call replaced by an - equivalent LAG window and all column identifiers quoted. - :raises ValueError: - If a ``PREV()`` call does not take exactly one argument, or if a - ``PREV()`` call is nested inside another (predicates compare only - the immediate predecessor). - """ - - def _is_prev(node: exp.Expression) -> bool: - return isinstance(node, exp.Anonymous) and node.name.upper() == "PREV" - - def _replace(node: exp.Expression) -> exp.Expression: - if _is_prev(node): - args = node.expressions - if len(args) != 1: - raise ValueError( - f"PREV() takes exactly one column argument; got {len(args)}." - ) - if any(_is_prev(inner) for inner in args[0].find_all(exp.Anonymous)): - raise ValueError( - "PREV() cannot be nested; a CLUSTER/MERGE predicate " - "compares only the immediate predecessor." - ) - return exp.Window( - this=exp.Anonymous(this="LAG", expressions=[args[0].copy()]), - partition_by=[col.copy() for col in partition_cols], - order=exp.Order(expressions=[term.copy() for term in order_by]), - ) - return node - - rewritten = predicate.copy().transform(_replace) - for column in rewritten.find_all(exp.Column): - column.this.set("quoted", True) - return rewritten - - -class MergeTransformer: - """Transforms queries containing MERGE into GROUP BY queries. - - MERGE combines overlapping intervals using CLUSTER + aggregation: - - SELECT MERGE(interval) FROM features - - Into: - - WITH clustered AS ( - SELECT *, CLUSTER(interval) AS __giql_cluster_id FROM features - ) - SELECT - chromosome, - MIN(start_pos) AS start_pos, - MAX(end_pos) AS end_pos - FROM clustered - GROUP BY chromosome, __giql_cluster_id - ORDER BY chromosome, start_pos - """ - - def __init__(self, tables: Tables): - """Initialize transformer. - - :param tables: - Table configurations for column mapping - """ - self.tables = tables - self.cluster_transformer = ClusterTransformer(tables) - - def transform(self, query: exp.Expression) -> exp.Expression: - """Transform query if it contains MERGE expressions. - - :param query: - Parsed query AST - :return: - Transformed query AST - """ - if not isinstance(query, exp.Select): - return query - - # First, recursively transform any CTEs that might contain MERGE - if query.args.get("with_"): - cte = query.args["with_"] - for cte_expr in cte.expressions: - if isinstance(cte_expr, exp.CTE): - # Transform the CTE's subquery - cte_expr.set("this", self.transform(cte_expr.this)) - - # Recursively transform subqueries in FROM clause - if query.args.get("from_"): - from_clause = query.args["from_"] - self._transform_subqueries_in_node(from_clause) - - # Recursively transform subqueries in JOIN clauses - if query.args.get("joins"): - for join in query.args["joins"]: - self._transform_subqueries_in_node(join) - - # Recursively transform subqueries in WHERE clause - if query.args.get("where"): - self._transform_subqueries_in_node(query.args["where"]) - - # Find all MERGE expressions in the SELECT clause - merge_exprs = self._find_merge_expressions(query) - - if not merge_exprs: - return query - - # For now, support only one MERGE expression - if len(merge_exprs) > 1: - raise ValueError("Multiple MERGE expressions not yet supported") - - merge_expr = merge_exprs[0] - return self._transform_for_merge(query, merge_expr) - - def _transform_subqueries_in_node(self, node: exp.Expression): - """Recursively transform subqueries within an expression node. - - :param node: - Expression node to search for subqueries - """ - # Find and transform any Subquery nodes - for subquery in node.find_all(exp.Subquery): - if isinstance(subquery.this, exp.Select): - transformed = self.transform(subquery.this) - subquery.set("this", transformed) - - def _find_merge_expressions(self, query: exp.Select) -> list[GIQLMerge]: - """Find all MERGE expressions in query. - - :param query: - Query to search - :return: - List of MERGE expressions - """ - merge_exprs = [] - - for expression in query.expressions: - if isinstance(expression, GIQLMerge): - merge_exprs.append(expression) - elif isinstance(expression, exp.Alias): - if isinstance(expression.this, GIQLMerge): - merge_exprs.append(expression.this) - - return merge_exprs - - def _transform_for_merge( - self, query: exp.Select, merge_expr: GIQLMerge - ) -> exp.Select: - """Transform query to compute MERGE using CLUSTER + GROUP BY. - - :param query: - Original query - :param merge_expr: - MERGE expression to transform - :return: - Transformed query with clustering and aggregation - """ - # Extract MERGE parameters (same as CLUSTER) - distance_expr = merge_expr.args.get("distance") - stranded_expr = merge_expr.args.get("stranded") - predicate_expr = merge_expr.args.get("predicate") - - # Get column names from table config or use defaults - ( - chrom_col, - start_col, - end_col, - strand_col, - ) = self.cluster_transformer._get_genomic_columns(query) - - # Build CLUSTER expression with same parameters - cluster_kwargs = {"this": merge_expr.this} - if distance_expr: - cluster_kwargs["distance"] = distance_expr - if stranded_expr: - cluster_kwargs["stranded"] = stranded_expr - if predicate_expr is not None: - cluster_kwargs["predicate"] = predicate_expr - - cluster_expr = GIQLCluster(**cluster_kwargs) - - # Create intermediate query with CLUSTER - # Start with original query's FROM/WHERE/etc - cluster_query = exp.Select() - cluster_query.select(exp.Star(), copy=False) - cluster_query.select( - exp.alias_(cluster_expr, "__giql_cluster_id", quoted=False), - append=True, - copy=False, - ) - - # Copy FROM, WHERE from original - # Use copy() to avoid sharing references between queries - if query.args.get("from_"): - cluster_query.set("from_", query.args["from_"].copy()) - if query.args.get("where"): - cluster_query.set("where", query.args["where"].copy()) - - # Apply CLUSTER transformation to get the CTE-based query - cluster_query = self.cluster_transformer.transform(cluster_query) - - # Build GROUP BY columns - group_by_cols = [exp.column(chrom_col)] - - # Handle stranded parameter - if stranded_expr: - if isinstance(stranded_expr, exp.Boolean): - stranded = stranded_expr.this - elif isinstance(stranded_expr, exp.Literal): - stranded = str(stranded_expr.this).upper() == "TRUE" - else: - stranded = str(stranded_expr).upper() in ("TRUE", "1", "YES") - else: - stranded = False - - if stranded: - group_by_cols.append(exp.column(strand_col, quoted=True)) - - group_by_cols.append(exp.column("__giql_cluster_id")) - - # Build SELECT expressions for merged intervals - select_exprs = [] - - # Add group-by columns (non-aggregated) - select_exprs.append(exp.column(chrom_col, quoted=True)) - if stranded: - select_exprs.append(exp.column(strand_col, quoted=True)) - - # Add merged interval bounds - select_exprs.append( - exp.alias_( - exp.Min(this=exp.column(start_col, quoted=True)), start_col, quoted=False - ) - ) - select_exprs.append( - exp.alias_( - exp.Max(this=exp.column(end_col, quoted=True)), end_col, quoted=False - ) - ) - - # Process other columns from original SELECT - for expression in query.expressions: - # Skip the MERGE expression itself - if isinstance(expression, GIQLMerge): - continue - elif isinstance(expression, exp.Alias) and isinstance( - expression.this, GIQLMerge - ): - continue - # Include other columns (they should be aggregates or in GROUP BY) - else: - select_exprs.append(expression) - - # Build final query - final_query = exp.Select() - final_query.select(*select_exprs, copy=False) - - # FROM the clustered subquery - subquery = exp.Subquery( - this=cluster_query, - alias=exp.TableAlias(this=exp.Identifier(this="clustered")), - ) - final_query.from_(subquery, copy=False) - - # Add GROUP BY - final_query.group_by(*group_by_cols, copy=False) - - # Add ORDER BY (chromosome, start) - final_query.order_by( - exp.Ordered(this=exp.column(chrom_col, quoted=True)), copy=False - ) - final_query.order_by( - exp.Ordered(this=exp.column(start_col, quoted=True)), append=True, copy=False - ) - - return final_query - - class IntersectsBinnedJoinTransformer: """Transform column-to-column INTERSECTS into binned equi-joins. diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 3ff986a..0cd934f 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -22,10 +22,8 @@ from giql.table import Table from giql.table import Tables from giql.targets import resolve_target -from giql.transformer import ClusterTransformer from giql.transformer import IntersectsBinnedJoinTransformer from giql.transformer import IntersectsDuckDBIEJoinTransformer -from giql.transformer import MergeTransformer @overload @@ -210,8 +208,6 @@ def transpile( if duckdb_sql is not None: return duckdb_sql - merge_transformer = MergeTransformer(tables_container) - cluster_transformer = ClusterTransformer(tables_container) generator = BaseGIQLGenerator(tables=tables_container) with _reraise_as_value_error("Transformation error"): @@ -219,19 +215,20 @@ def transpile( # declined the query (returned None) and fell back to the binned plan, # exactly as before. ``intersects_bin_size`` is rejected up front for # iejoin targets, so the binned transformer always sees its default there. + # + # CLUSTER and MERGE used to be rewritten here too; they are now expanded in + # pass 3 (ExpandOperators) by giql.expanders.cluster / .merge (#144). if not target_overrides_intersects: intersects_transformer = IntersectsBinnedJoinTransformer( tables_container, bin_size=intersects_bin_size, ) ast = intersects_transformer.transform(ast) - ast = merge_transformer.transform(ast) - ast = cluster_transformer.transform(ast) # Pass 1 of the normalization pipeline (epic #114): attach resolution - # metadata to every GIQL operator slot ahead of generation. DISJOIN's - # expander consumes this metadata (step 2); the remaining operators still - # use the generator's legacy resolver paths until their ports land. + # metadata to every GIQL operator slot ahead of generation. Every migrated + # operator's expander consumes this metadata in pass 3 (CLUSTER/MERGE carry an + # empty resolution, deriving their columns from the FROM table instead). with _reraise_as_value_error("Resolution error"): ast = resolve_operator_refs(ast, tables_container) diff --git a/tests/expanders/test_cluster.py b/tests/expanders/test_cluster.py new file mode 100644 index 0000000..5fae50a --- /dev/null +++ b/tests/expanders/test_cluster.py @@ -0,0 +1,370 @@ +"""Behavioral tests for the CLUSTER operator expander (#144). + +CLUSTER migrated from the pre-pass ``ClusterTransformer`` to a registered +expander (``giql.expanders.cluster``) that rewrites the enclosing SELECT in place +into the two-level ``lag_calc`` form. These tests drive the public ``transpile`` +API and pin the behaviors the migration newly exercises — the pass walk +replacing the transformer's manual recursion (nested placements), the +FROM-table column derivation, the distance/clause branches, and projection +shapes — that the legacy transpilation suites did not already cover. The +predicate byte-shape is pinned in ``tests/test_cluster_predicate_transpilation.py``. +""" + +import pytest + +from giql.table import Table +from giql.transpile import transpile + + +class TestClusterExpander: + """Expansion of CLUSTER through the operator-expander registry (#144).""" + + @pytest.mark.parametrize( + "query", + [ + "SELECT * FROM (SELECT *, CLUSTER(interval) AS cid FROM peaks) x", + "WITH c AS (SELECT *, CLUSTER(interval) AS cid FROM peaks) SELECT * FROM c", + ], + ) + def test_transpile_should_expand_cluster_nested_in_subquery_or_cte(self, query): + """Test that a CLUSTER nested below the root SELECT still expands. + + Given: + A CLUSTER inside a FROM-subquery, and a CLUSTER inside a WITH CTE. + When: + Transpiling the query. + Then: + The nested CLUSTER should expand into the two-level lag_calc form with + no leaked operator — the pass walk replaces the manual recursion the + legacy transformer performed. + """ + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "G_I_Q_L" not in sql + assert "AS lag_calc" in sql + assert "is_new_cluster" in sql + + def test_transpile_should_raise_when_multiple_cluster_in_one_select(self): + """Test that two CLUSTER expressions in one SELECT are rejected. + + Given: + A SELECT projecting two CLUSTER expressions. + When: + Transpiling the query. + Then: + It should raise ValueError naming the unsupported multiple-CLUSTER + case, rather than chaining the rewrite into non-executable SQL (a + duplicate lag_calc alias / is_new_cluster binder error). + """ + # Arrange + query = ( + "SELECT *, CLUSTER(interval) AS a, CLUSTER(interval, 100) AS b FROM peaks" + ) + + # Act & assert + with pytest.raises( + ValueError, match="Multiple CLUSTER expressions not yet supported" + ): + transpile(query, tables=["peaks"]) + + def test_transpile_should_use_custom_columns_when_table_declares_them(self): + """Test that a stranded CLUSTER honors a custom column mapping. + + Given: + A stranded CLUSTER over a Table declaring custom column names. + When: + Transpiling the query. + Then: + The window partition and order should use the custom column names, not + the canonical defaults. + """ + # Arrange + regions = Table( + "regions", chrom_col="ch", start_col="s", end_col="e", strand_col="st" + ) + query = "SELECT *, CLUSTER(interval, stranded := true) AS cid FROM regions" + + # Act + sql = transpile(query, tables=[regions]) + + # Assert + assert 'PARTITION BY "ch", "st" ORDER BY "s"' in sql + assert 'LAG("e")' in sql + assert '"chrom"' not in sql and '"start"' not in sql + + def test_transpile_should_add_distance_offset_to_lag_when_distance_positive(self): + """Test that a positive CLUSTER distance offsets the adjacency LAG. + + Given: + A CLUSTER with a positive distance. + When: + Transpiling the query. + Then: + The adjacency should add the distance to the LAG before comparing to + start. + """ + # Arrange + query = "SELECT *, CLUSTER(interval, 100) AS cid FROM peaks" + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + window = 'OVER (PARTITION BY "chrom" ORDER BY "start" NULLS LAST)' + assert f'LAG("end") {window} + 100 >= "start"' in sql + + def test_transpile_should_not_offset_lag_when_no_distance(self): + """Test that a CLUSTER without distance uses a bare adjacency. + + Given: + A CLUSTER with no distance argument. + When: + Transpiling the query. + Then: + The adjacency should compare the bare LAG to start with no offset. + """ + # Arrange + query = "SELECT *, CLUSTER(interval) AS cid FROM peaks" + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + window = 'OVER (PARTITION BY "chrom" ORDER BY "start" NULLS LAST)' + assert f'LAG("end") {window} >= "start"' in sql + assert f'{window} + ' not in sql + + def test_transpile_should_split_clauses_between_lag_calc_and_outer_query(self): + """Test that CLUSTER places clauses at the correct query level. + + Given: + A CLUSTER query carrying WHERE, GROUP BY, HAVING, and ORDER BY. + When: + Transpiling the query. + Then: + WHERE/GROUP BY/HAVING should land inside the inner lag_calc subquery + and ORDER BY should attach to the outer query. + """ + # Arrange + query = ( + "SELECT chrom, CLUSTER(interval) AS cid FROM peaks " + "WHERE chrom = 'chr1' GROUP BY chrom HAVING COUNT(*) > 1 ORDER BY chrom" + ) + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert ( + "WHERE chrom = 'chr1' GROUP BY chrom HAVING COUNT(*) > 1) AS lag_calc" in sql + ) + assert ") AS lag_calc ORDER BY chrom" in sql + + def test_transpile_should_expand_bare_cluster_without_alias(self): + """Test that an un-aliased CLUSTER expands to a bare SUM window. + + Given: + A CLUSTER projected without an AS alias. + When: + Transpiling the query. + Then: + The bare CLUSTER should be replaced by an un-aliased SUM window with no + leaked operator. + """ + # Arrange + query = "SELECT *, CLUSTER(interval) FROM peaks" + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "G_I_Q_L" not in sql + assert "SUM(is_new_cluster) OVER" in sql + + def test_transpile_should_keep_explicit_projection_columns_in_lag_calc(self): + """Test that explicit (non-star) projection columns flow into lag_calc. + + Given: + A CLUSTER query with an explicit column projection (not SELECT *). + When: + Transpiling the query. + Then: + The explicit columns should be projected by the inner lag_calc + subquery feeding the cluster window. + """ + # Arrange + query = "SELECT chrom, start, CLUSTER(interval) AS c FROM peaks" + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "SELECT chrom, start," in sql + assert "AS lag_calc" in sql + assert "G_I_Q_L" not in sql + + # Note: CLUSTER combined with an INTERSECTS *join* in the same SELECT is not + # tested here. The clause copy into lag_calc carries only FROM/WHERE/GROUP/ + # HAVING (never JOINs), so a join is dropped — pre-existing legacy behavior + # that #144 preserves byte-for-byte, not a migration concern. + + def test_transpile_should_compose_distance_stranded_and_predicate(self): + """Test that CLUSTER composes distance, stranded, and predicate together. + + Given: + A CLUSTER with a distance, stranded mode, and a PREV-based predicate. + When: + Transpiling the query. + Then: + The output should offset the LAG by the distance, partition by chrom + and strand, and rewrite PREV into a LAG window inside the adjacency. + """ + # Arrange + query = ( + "SELECT *, CLUSTER(interval, 1000, stranded := true, " + "predicate := name = PREV(name)) AS cid FROM peaks" + ) + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "+ 1000 >= " in sql + assert 'PARTITION BY "chrom", "strand"' in sql + assert 'LAG("name")' in sql + assert "G_I_Q_L" not in sql + + def test_transpile_should_expand_cluster_in_union_branch(self): + """Test that a CLUSTER inside a UNION branch is expanded. + + Given: + A CLUSTER in each branch of a UNION (a shape the legacy transformer's + manual recursion did not descend into, leaking an unexpanded operator). + When: + Transpiling the query. + Then: + Both branches should expand to the lag_calc form with no leaked + operator — the pass walk reaches UNION branches. + """ + # Arrange + query = ( + "SELECT *, CLUSTER(interval) AS c FROM peaks " + "UNION ALL SELECT *, CLUSTER(interval) AS c FROM peaks" + ) + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "G_I_Q_L" not in sql + assert sql.count("AS lag_calc") == 2 + + @pytest.mark.parametrize("predicate_op", ["INTERSECTS", "CONTAINS", "WITHIN"]) + @pytest.mark.parametrize( + "projection", + ["*, CLUSTER(interval) AS cid", "CLUSTER(interval)"], + ids=["aliased", "bare"], + ) + def test_transpile_should_expand_spatial_predicate_copied_into_lag_calc( + self, projection, predicate_op + ): + """Test that a spatial WHERE predicate survives the CLUSTER rewrite. + + Given: + A CLUSTER query (aliased or bare) whose WHERE filters on a spatial + predicate, which the rewrite copies into the inner lag_calc subquery. + When: + Transpiling the query. + Then: + The copied predicate should itself be expanded — no leaked, unexpanded + operator — for both projection depths, pinning the #144 B1 regression + where the aliased CLUSTER expanded before the predicate and stranded a + live, unexpanded copy in the subquery. + """ + # Arrange + query = ( + f"SELECT {projection} FROM peaks a " + f"WHERE a.interval {predicate_op} 'chr1:1-100'" + ) + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "G_I_Q_L" not in sql + assert "AS lag_calc" in sql + + @pytest.mark.parametrize( + "query", + [ + "SELECT ABS(CLUSTER(interval)) FROM peaks", + "SELECT CLUSTER(interval) + 1 AS c FROM peaks", + ], + ) + def test_transpile_should_raise_when_cluster_nested_in_projection_expression( + self, query + ): + """Test that a CLUSTER buried in a projection expression is rejected. + + Given: + A CLUSTER nested inside a larger projection expression (a function call + or arithmetic), which has no coherent whole-query rewrite. + When: + Transpiling the query. + Then: + It should raise ValueError requiring a top-level projection item, + rather than leaking an unexpanded operator to the generator. + """ + # Act & assert + with pytest.raises(ValueError, match="must be a top-level projection item"): + transpile(query, tables=["peaks"]) + + def test_transpile_should_inject_lag_calc_columns_deterministically(self): + """Test that the injected lag_calc column order is hash-seed independent. + + Given: + An explicit-projection CLUSTER whose predicate forces several residual + columns to be injected into lag_calc, transpiled in two child + interpreters under differing PYTHONHASHSEED values. + When: + Comparing the two emitted strings. + Then: + They should be byte-identical, proving the injected-column order is + sorted rather than set-iteration (PYTHONHASHSEED) dependent (#144 A2). + """ + # Arrange + import os + import subprocess + import sys + + code = ( + "from giql.transpile import transpile; " + "print(transpile(" + "\"SELECT chrom, CLUSTER(interval, stranded := true, " + "predicate := name = PREV(score)) AS c FROM peaks\", tables=['peaks']))" + ) + base_env = { + k: v + for k, v in os.environ.items() + if not k.startswith("COV_CORE") and k != "COVERAGE_PROCESS_START" + } + + def _run(seed: str) -> str: + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + env={**base_env, "PYTHONHASHSEED": seed}, + ) + assert result.returncode == 0, result.stderr + return result.stdout + + # Act + out_a = _run("0") + out_b = _run("1") + + # Assert + assert out_a == out_b + assert "G_I_Q_L" not in out_a diff --git a/tests/expanders/test_merge.py b/tests/expanders/test_merge.py new file mode 100644 index 0000000..fcdc9d7 --- /dev/null +++ b/tests/expanders/test_merge.py @@ -0,0 +1,292 @@ +"""Behavioral tests for the MERGE operator expander (#144). + +MERGE migrated from the pre-pass ``MergeTransformer`` to a registered expander +(``giql.expanders.merge``) that rewrites the enclosing SELECT in place into the +clustered-aggregation form. These tests drive the public ``transpile`` API and +pin the behaviors the migration newly exercises (the registry pass, the +recursion-removal, and the error/limitation guards) that the legacy +transpilation suites did not already cover. The predicate byte-shape is pinned +separately in ``tests/test_cluster_predicate_transpilation.py``. +""" + +import pytest + +from giql.table import Table +from giql.transpile import transpile + + +class TestMergeExpander: + """Expansion of MERGE through the operator-expander registry (#144).""" + + def test_transpile_should_raise_when_multiple_merge_in_one_select(self): + """Test that two MERGE expressions in one SELECT are rejected. + + Given: + A SELECT projecting two MERGE expressions. + When: + Transpiling the query. + Then: + It should raise ValueError naming the unsupported multiple-MERGE case. + """ + # Arrange + query = "SELECT MERGE(interval), MERGE(interval, 100) FROM peaks" + + # Act & assert + with pytest.raises( + ValueError, match="Multiple MERGE expressions not yet supported" + ): + transpile(query, tables=["peaks"]) + + @pytest.mark.parametrize( + "query", + [ + "SELECT MERGE(interval), CLUSTER(interval) AS cid FROM peaks", + "SELECT CLUSTER(interval) AS cid, MERGE(interval) FROM peaks", + ], + ) + def test_transpile_should_raise_when_cluster_and_merge_share_select(self, query): + """Test that combining CLUSTER and MERGE in one SELECT is rejected. + + Given: + A SELECT projecting both a MERGE and a CLUSTER (either order). + When: + Transpiling the query. + Then: + It should raise ValueError naming the unsupported combination, rather + than silently emitting SQL with a leaked, unexpanded operator. + """ + # Act & assert + with pytest.raises(ValueError, match="CLUSTER and MERGE cannot be combined"): + transpile(query, tables=["peaks"]) + + def test_transpile_should_group_by_strand_when_merge_stranded(self): + """Test that a stranded MERGE aggregates within strand. + + Given: + A stranded MERGE query. + When: + Transpiling the query. + Then: + It should aggregate MIN(start)/MAX(end), group by chrom, strand, and + the synthesized cluster id, project strand, and partition the inner + windows by chrom and strand. + """ + # Arrange + query = "SELECT MERGE(interval, stranded := true) FROM peaks" + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert 'MIN("start") AS start' in sql + assert 'MAX("end") AS end' in sql + assert 'GROUP BY "chrom", "strand", __giql_cluster_id' in sql + assert 'PARTITION BY "chrom", "strand"' in sql + assert "G_I_Q_L" not in sql + + def test_transpile_should_use_custom_columns_when_table_declares_them(self): + """Test that a stranded MERGE honors a custom column mapping. + + Given: + A stranded MERGE over a Table declaring custom chrom/start/end/strand + column names. + When: + Transpiling the query. + Then: + The aggregation, GROUP BY, and window partitions should use the custom + column names, never the canonical defaults. + """ + # Arrange + regions = Table( + "regions", chrom_col="ch", start_col="s", end_col="e", strand_col="st" + ) + query = "SELECT MERGE(interval, stranded := true) FROM regions" + + # Act + sql = transpile(query, tables=[regions]) + + # Assert + assert 'MIN("s") AS s' in sql + assert 'MAX("e") AS e' in sql + assert 'GROUP BY "ch", "st", __giql_cluster_id' in sql + assert 'PARTITION BY "ch", "st"' in sql + assert '"chrom"' not in sql and '"start"' not in sql + + @pytest.mark.parametrize( + "query", + [ + "SELECT * FROM (SELECT MERGE(interval) FROM peaks) x", + "WITH c AS (SELECT MERGE(interval) FROM peaks) SELECT * FROM c", + ], + ) + def test_transpile_should_expand_merge_nested_in_subquery_or_cte(self, query): + """Test that a MERGE nested below the root SELECT still expands. + + Given: + A MERGE inside a FROM-subquery, and a MERGE inside a WITH CTE. + When: + Transpiling the query. + Then: + The nested MERGE should expand into the clustered-aggregation form + with no leaked operator — the pass walk replaces the manual recursion + the legacy transformer performed. + """ + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "G_I_Q_L" not in sql + assert "AS clustered" in sql + assert "__giql_cluster_id" in sql + + def test_transpile_should_carry_where_into_clustered_subquery(self): + """Test that a WHERE clause is pushed into MERGE's clustered subquery. + + Given: + A MERGE query with a WHERE clause. + When: + Transpiling the query. + Then: + The WHERE predicate should appear inside the inner clustered subquery + that feeds the aggregation. + """ + # Arrange + query = "SELECT MERGE(interval) AS m FROM peaks WHERE chrom = 'chr1'" + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "WHERE chrom = 'chr1'" in sql + assert "AS clustered" in sql + + @pytest.mark.parametrize( + "predicate, message", + [ + ("depth = PREV(depth, score)", "exactly one column argument"), + ("depth = PREV(PREV(depth))", "cannot be nested"), + ], + ) + def test_transpile_should_validate_prev_in_merge_predicate(self, predicate, message): + """Test that MERGE reuses CLUSTER's PREV validation. + + Given: + A MERGE predicate calling PREV with the wrong arity or a nested PREV. + When: + Transpiling the query. + Then: + It should raise the matching ValueError, proving the MERGE path + reaches the shared predecessor-reference validation. + """ + # Arrange + query = f"SELECT MERGE(interval, predicate := {predicate}) FROM peaks" + + # Act & assert + with pytest.raises(ValueError, match=message): + transpile(query, tables=["peaks"]) + + def test_transpile_should_expand_merge_in_projection_scalar_subquery(self): + """Test that a MERGE in a projection scalar-subquery is expanded. + + Given: + A MERGE inside a scalar subquery in the SELECT list (a shape the legacy + transformer's manual recursion did not descend into, leaking an + unexpanded operator). + When: + Transpiling the query. + Then: + The MERGE should expand into the clustered-aggregation form with no + leaked operator — the pass walk reaches projection subqueries. + """ + # Arrange + query = "SELECT (SELECT MERGE(interval) FROM peaks) AS m FROM other" + + # Act + sql = transpile(query, tables=["peaks", "other"]) + + # Assert + assert "G_I_Q_L" not in sql + assert "AS clustered" in sql + assert "__giql_cluster_id" in sql + + @pytest.mark.parametrize("predicate_op", ["INTERSECTS", "CONTAINS", "WITHIN"]) + @pytest.mark.parametrize( + "projection", + ["*, MERGE(interval) AS m", "MERGE(interval)"], + ids=["aliased", "bare"], + ) + def test_transpile_should_expand_spatial_predicate_copied_into_clustered( + self, projection, predicate_op + ): + """Test that a spatial WHERE predicate survives the MERGE rewrite. + + Given: + A MERGE query (aliased or bare) whose WHERE filters on a spatial + predicate, which the rewrite copies into the inner clustered subquery. + When: + Transpiling the query. + Then: + The copied predicate should itself be expanded — no leaked, unexpanded + operator — for both projection depths, pinning the #144 B1 regression + where the aliased MERGE expanded before the predicate and stranded a + live, unexpanded copy in the subquery. + """ + # Arrange + query = ( + f"SELECT {projection} FROM peaks a " + f"WHERE a.interval {predicate_op} 'chr1:1-100'" + ) + + # Act + sql = transpile(query, tables=["peaks"]) + + # Assert + assert "G_I_Q_L" not in sql + assert "AS clustered" in sql + + @pytest.mark.parametrize( + "query", + [ + "SELECT ABS(MERGE(interval)) FROM peaks", + "SELECT MERGE(interval) + 1 AS m FROM peaks", + ], + ) + def test_transpile_should_raise_when_merge_nested_in_projection_expression( + self, query + ): + """Test that a MERGE buried in a projection expression is rejected. + + Given: + A MERGE nested inside a larger projection expression (a function call or + arithmetic), which has no coherent whole-query rewrite. + When: + Transpiling the query. + Then: + It should raise ValueError requiring a top-level projection item, + rather than leaking an unexpanded operator to the generator. + """ + # Act & assert + with pytest.raises(ValueError, match="must be a top-level projection item"): + transpile(query, tables=["peaks"]) + + def test_transpile_should_quote_group_by_chrom_when_chrom_is_reserved_word(self): + """Test that the MERGE GROUP BY chrom term is quoted. + + Given: + A MERGE over a Table whose chrom column is a SQL reserved word. + When: + Transpiling the query. + Then: + The GROUP BY chrom term should be quoted (like every other chrom + reference), so the reserved-word column emits valid SQL (#144 A13). + """ + # Arrange + regions = Table("regions", chrom_col="order", start_col="s", end_col="e") + query = "SELECT MERGE(interval) FROM regions" + + # Act + sql = transpile(query, tables=[regions]) + + # Assert + assert 'GROUP BY "order"' in sql + assert "GROUP BY order" not in sql diff --git a/tests/test_expander.py b/tests/test_expander.py index ab93732..7844552 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -1172,44 +1172,38 @@ def test_expand_operators_is_identity_when_registry_empty(self): assert result is ast assert list(result.find_all(GIQLDisjoin)) - def test_transpile_sql_unchanged_with_pass_inert(self): - """Test that transpile output is byte-identical for an unmigrated operator. + def test_expand_operators_skips_opted_out_operator_with_default_registry(self): + """Test that the pass skips an opted-out operator even with its expander live. Given: - A CLUSTER query (an operator not migrated onto the pass in any wave-3 - branch, so its GIQL_EXPAND is False and no expander resolves), with - the default registry. + A migrated operator query and the import-populated default REGISTRY (so + its expander resolves), but the operator opted out of GIQL_EXPAND for + the test. When: - Running the ExpandOperators pass (default REGISTRY) over the resolved - AST and serializing both the original and the pass-run AST. + Running the ExpandOperators pass over the default REGISTRY. Then: - The pass leaves the operator node in place, the serialized SQL is - byte-identical, and no expander alias prefix appears — the pass is - inert for any operator that has not been migrated. + It should leave the operator node in place and return the same tree — + the per-type GIQL_EXPAND gate holds even when an expander is registered + (after #144 no operator ships unmigrated, so the gate is exercised via + an opt-out rather than a shipped False). """ # Arrange - query = "SELECT *, CLUSTER(interval) AS cluster_id FROM peaks" - tables = _tables(("peaks",)) - ast = _prepare(query, tables) - from giql.generators import BaseGIQLGenerator - - before = BaseGIQLGenerator(tables=tables).generate(ast) - before_ops = len(list(ast.find_all(GIQLCluster))) + operator = _A_MIGRATED_OPERATOR + ast, tables = _prepare_operator(operator) + before_ops = len(list(ast.find_all(operator))) # Act — the wired-in pass over the default REGISTRY must be a no-op here. - result = expand_operators(ast, GenericTarget(), tables) - after = BaseGIQLGenerator(tables=tables).generate(result) + with _opted_out(operator): + result = expand_operators(ast, GenericTarget(), tables) # Assert - assert after == before - assert len(list(result.find_all(GIQLCluster))) == before_ops - assert EXPAND_ALIAS_PREFIX not in after + assert result is ast + assert len(list(result.find_all(operator))) == before_ops # The nine GIQL operator expression classes the ExpandOperators pass inspects. -# Each migrated operator ships opted in (GIQL_EXPAND=True) alongside its -# registered expander; the rest ship opted out (False) and fall through to the -# legacy emitter. +# Every operator is now migrated (#144): each ships opted in (GIQL_EXPAND=True) +# alongside its registered expander, so none falls through to a legacy emitter. from giql.expressions import Contains # noqa: E402 from giql.expressions import GIQLCluster # noqa: E402 from giql.expressions import GIQLDistance # noqa: E402 @@ -1252,11 +1246,15 @@ def test_transpile_sql_unchanged_with_pass_inert(self): assert _MIGRATED_OPERATORS, "expected at least one migrated operator" #: An arbitrary migrated operator the operator-agnostic control tests target. _A_MIGRATED_OPERATOR = _MIGRATED_OPERATORS[0] -#: Operators not yet migrated — they ship GIQL_EXPAND=False. +#: Operators not yet migrated — they ship GIQL_EXPAND=False. Empty since #144 +#: migrated the last two (CLUSTER and MERGE): every operator now expands through +#: the pass. Control tests that need an operator to behave as if unmigrated drive a +#: migrated one through ``_opted_out`` rather than relying on a shipped ``False``. _UNMIGRATED_OPERATORS = tuple( op for op in _OPERATOR_CLASSES if op not in _MIGRATED_OPERATORS ) -assert _UNMIGRATED_OPERATORS, "expected at least one unmigrated operator" +#: Pin the post-#144 invariant: every GIQL operator is migrated onto the pass. +assert not _UNMIGRATED_OPERATORS, "every operator should be migrated after #144" #: A minimal GIQL query producing one node of each operator class, keyed by the @@ -1314,28 +1312,90 @@ def _prepare_operator(operator: type) -> tuple[exp.Expression, Tables]: return _prepare(query, tables), tables -class TestOperatorOptOut: - """Migrated operators opt into the pass; the rest still ship opted out.""" +class TestClusterMergeExpansion: + """CLUSTER and MERGE expand through the pass into their restructured forms (#144).""" - @pytest.mark.parametrize( - "operator", _UNMIGRATED_OPERATORS, ids=lambda c: c.__name__ - ) - def test_operator_class_ships_expand_disabled(self, operator): - """Test that each unmigrated operator class ships GIQL_EXPAND=False. + def test_transform_replaces_cluster_with_lag_calc_subquery(self): + """Test that the pass rewrites a CLUSTER query into the two-level form. Given: - A GIQL operator expression class that has not been migrated onto the - ExpandOperators pass. + A resolved ``SELECT *, CLUSTER(interval) ...`` AST and the default + REGISTRY (CLUSTER ships GIQL_EXPAND=True with a registered expander). When: - Reading its GIQL_EXPAND class attribute. + Running the ExpandOperators pass. Then: - It should be False (the operator still uses the legacy emitter). + It should consume the CLUSTER node in place (returning the same root + object), wrap the source in a ``lag_calc`` derived table with a LAG + window and an ``is_new_cluster`` CASE, project an outer SUM window, and + mint no expander alias. """ - # Arrange & act - flag = operator.GIQL_EXPAND + # Arrange + ast, tables = _prepare_operator(GIQLCluster) + + # Act + result = expand_operators(ast, GenericTarget(), tables) # Assert - assert flag is False + assert result is ast # whole-query rewrite mutates the root in place + assert not list(result.find_all(GIQLCluster)) + aliases = {sub.alias for sub in result.find_all(exp.Subquery) if sub.alias} + assert "lag_calc" in aliases + windows = list(result.find_all(exp.Window)) + assert any(isinstance(w.this, exp.Sum) for w in windows) # outer cluster id + assert any( + isinstance(w.this, exp.Anonymous) and w.this.name.upper() == "LAG" + for w in windows + ) # inner adjacency LAG + assert any( + isinstance(a, exp.Alias) and a.alias == "is_new_cluster" + for a in result.find_all(exp.Alias) + ) + assert EXPAND_ALIAS_PREFIX not in result.sql(dialect=GIQLDialect) + + def test_transform_replaces_merge_with_clustered_group_by(self): + """Test that the pass rewrites a MERGE query into the clustered-aggregation form. + + Given: + A resolved ``SELECT MERGE(interval) ...`` AST and the default REGISTRY + (MERGE ships GIQL_EXPAND=True with a registered expander). + When: + Running the ExpandOperators pass. + Then: + It should consume the MERGE node in place (returning the same root + object), wrap a ``clustered`` subquery (itself wrapping a ``lag_calc``) + under a GROUP BY that includes the synthesized ``__giql_cluster_id``, + project MIN/MAX bounds, and mint no expander alias. + """ + # Arrange + ast, tables = _prepare_operator(GIQLMerge) + + # Act + result = expand_operators(ast, GenericTarget(), tables) + + # Assert + assert result is ast # whole-query rewrite mutates the root in place + assert not list(result.find_all(GIQLMerge)) + aliases = {sub.alias for sub in result.find_all(exp.Subquery) if sub.alias} + assert "clustered" in aliases # MERGE wraps the clustered subquery + assert "lag_calc" in aliases # built on CLUSTER + group = result.find(exp.Group) + assert group is not None + assert any( + isinstance(g, exp.Column) and g.name == "__giql_cluster_id" + for g in group.expressions + ) + assert any(isinstance(m, exp.Min) for m in result.find_all(exp.Min)) + assert any(isinstance(m, exp.Max) for m in result.find_all(exp.Max)) + assert EXPAND_ALIAS_PREFIX not in result.sql(dialect=GIQLDialect) + + +class TestOperatorOptOut: + """Every operator is now migrated, so all ship GIQL_EXPAND=True. + + The complementary ``test_operator_class_ships_expand_disabled`` was dropped + when #144 migrated the last operators: ``_UNMIGRATED_OPERATORS`` is empty, so + there is no class left to assert ships ``False``. + """ @pytest.mark.parametrize( "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ @@ -1392,24 +1452,27 @@ def test_opted_in_restores_flag_after_exception(self): """Test that _opted_in restores GIQL_EXPAND when the body raises. Given: - An operator class at its default GIQL_EXPAND=False. + An operator driven to GIQL_EXPAND=False via _opted_out (every operator + now ships True after #144, so the restore target is set up explicitly). When: Its _opted_in body raises an exception. Then: The flag should be restored to False (the manager is exception-safe, so a raising expansion test cannot leak an opt-in into a later test). """ - # Arrange - assert GIQLMerge.GIQL_EXPAND is False + # Arrange — set the restore target to False so the restore is observable. + operator = _A_MIGRATED_OPERATOR + with _opted_out(operator): + assert operator.GIQL_EXPAND is False - # Act - with pytest.raises(RuntimeError): - with _opted_in(GIQLMerge): - assert GIQLMerge.GIQL_EXPAND is True - raise RuntimeError("boom") + # Act + with pytest.raises(RuntimeError): + with _opted_in(operator): + assert operator.GIQL_EXPAND is True + raise RuntimeError("boom") - # Assert - assert GIQLMerge.GIQL_EXPAND is False + # Assert + assert operator.GIQL_EXPAND is False class TestIEJoinRegistryDeferral: @@ -1938,22 +2001,19 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): """Test that only the flagged operator type is replaced when both registered. Given: - A genuinely-unmigrated operator (GIQLCluster, shipping - GIQL_EXPAND=False in every wave-3 branch) and an INTERSECTS, both with - registered expanders, but only INTERSECTS flagged GIQL_EXPAND for the - test. + A CLUSTER and an INTERSECTS, both with registered expanders, but + CLUSTER held off (opted out of GIQL_EXPAND) while only INTERSECTS is + opted in. When: Running the pass. Then: - The INTERSECTS is replaced while the unmigrated operator node remains - on its own shipped ``False`` flag — no opt-out ceremony needed (the - gate is per-type). + The INTERSECTS is replaced while the held-off operator node remains — + the gate is per-type, so opting CLUSTER out alone holds its expansion + off even though its expander is registered. """ - # Arrange — the held-off subject is genuinely unmigrated: it survives on - # its own shipped GIQL_EXPAND=False, not on a test opt-out. + # Arrange — the held-off subject is a migrated operator driven off via + # _opted_out (after #144 no operator ships GIQL_EXPAND=False). held_off = GIQLCluster - assert held_off in _UNMIGRATED_OPERATORS - assert held_off.GIQL_EXPAND is False clean_registry.register( GenericTarget(), held_off, lambda n, c: exp.column("CL") ) @@ -1968,9 +2028,10 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): ) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act — only INTERSECTS is opted in; the unmigrated operator stays off. - with _opted_in(Intersects): - result = pass_.transform(ast) + # Act — only INTERSECTS is opted in; CLUSTER is held off via opt-out. + with _opted_out(held_off): + with _opted_in(Intersects): + result = pass_.transform(ast) # Assert assert list(result.find_all(held_off)) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 5bf6138..754474e 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -238,6 +238,7 @@ def test_cluster_basic(self): assert "SELECT" in sql assert "SUM" in sql.upper() or "LAG" in sql.upper() + assert "G_I_Q_L" not in sql # no leaked, unexpanded operator def test_cluster_with_distance(self): """ @@ -255,6 +256,7 @@ def test_cluster_with_distance(self): assert "SELECT" in sql assert "100" in sql + assert "G_I_Q_L" not in sql # no leaked, unexpanded operator def test_cluster_stranded(self): """ @@ -272,6 +274,7 @@ def test_cluster_stranded(self): assert "SELECT" in sql assert "strand" in sql.lower() + assert "G_I_Q_L" not in sql # no leaked, unexpanded operator class TestTranspileMerge: @@ -292,6 +295,7 @@ def test_merge_basic(self): assert "MIN" in sql.upper() assert "MAX" in sql.upper() assert "GROUP BY" in sql.upper() + assert "G_I_Q_L" not in sql # no leaked, unexpanded operator def test_merge_with_distance(self): """ @@ -306,6 +310,7 @@ def test_merge_with_distance(self): assert "SELECT" in sql assert "100" in sql + assert "G_I_Q_L" not in sql # no leaked, unexpanded operator def test_merge_with_aggregation(self): """ @@ -320,6 +325,7 @@ def test_merge_with_aggregation(self): assert "SELECT" in sql assert "COUNT" in sql.upper() + assert "G_I_Q_L" not in sql # no leaked, unexpanded operator class TestTranspileNearest: