diff --git a/docs/dialect/distance-operators.rst b/docs/dialect/distance-operators.rst index d3ebe89..ad333d4 100644 --- a/docs/dialect/distance-operators.rst +++ b/docs/dialect/distance-operators.rst @@ -309,6 +309,15 @@ Find nearby same-strand features within distance constraints: WHERE nearest.distance BETWEEN -10000 AND 10000 ORDER BY peaks.name, ABS(nearest.distance) +Target support +~~~~~~~~~~~~~~ + +A correlated ``NEAREST`` (its reference is an outer-row column) runs on lateral-capable engines — DuckDB and the generic target — via a correlated ``LATERAL`` subquery, and on Apache DataFusion, which has no correlated-``LATERAL`` physical plan, via a decorrelated window-function rewrite. For an **explicitly-projected** query (one that selects named columns, e.g. ``SELECT a.start, b.start, b.distance``) the two forms return identical results: the ``(start, end)`` tiebreaker orders rows tied at the k-th distance the same way on every engine, deterministically whenever ``(start, end)`` distinguishes the tied candidates. A standalone ``NEAREST`` with a literal reference is uncorrelated and uses the same ordered, limited subquery on every target. + +.. note:: + + **Known limitation —** ``SELECT *`` **/** ``SELECT b.*`` **over a correlated NEAREST on DataFusion.** The decorrelated window-function rewrite needs its reference-key and rank columns (``__giql_x_rk_*``, ``__giql_x_rn``) visible on the rewritten join, so a ``SELECT *`` or ``SELECT b.*`` over a correlated NEAREST exposes those reserved internal columns on DataFusion — a different output schema than the LATERAL form emits on DuckDB. The cross-target identity claim above therefore holds for **explicitly-projected** queries only. Projecting named columns avoids the leak entirely. A query-level wrapper that projects the reserved columns away on the DataFusion path is tracked by issue #160 (it depends on the query-level expander seam from #146). 
+ Notes ~~~~~ diff --git a/src/giql/expanders/nearest.py b/src/giql/expanders/nearest.py new file mode 100644 index 0000000..482c232 --- /dev/null +++ b/src/giql/expanders/nearest.py @@ -0,0 +1,474 @@ +"""The NEAREST operator expander (epic #137, issue #142). + +NEAREST is the first operator whose expansion is genuinely capability-driven. +The portable form is a correlated ``LATERAL`` subquery: each outer row drives a +``SELECT ... FROM {target} WHERE {filters} ORDER BY ABS(distance) +LIMIT k`` whose reference endpoints are outer-table columns. DuckDB and the +generic target plan that directly (``supports_lateral == True``). + +Apache DataFusion has no correlated-``LATERAL`` physical plan +(``supports_lateral == False``). For it the same k-nearest / ``max_distance`` / +``stranded`` / ``signed`` semantics are reproduced with a **decorrelated +window-function fallback**: the target is cross-joined against the outer +relation, each candidate is ranked with +``ROW_NUMBER() OVER (PARTITION BY {reference key} ORDER BY ABS(distance))``, and +the surrounding ``CROSS JOIN LATERAL`` is rewritten into a plain join that +re-associates the top-``k`` ranked candidates back to every outer row sharing +that reference key. Ranking depends only on the reference value, so ranking once +per distinct reference value and re-joining is set-equivalent to the per-row +LATERAL form — deterministic when the ``(start, end)`` tiebreaker distinguishes +candidates tied at the k-th distance (verified by the cross-target result +oracle). + +A literal-reference (standalone) NEAREST is already an uncorrelated subquery, so +every target — DataFusion included — uses the LATERAL/standalone form unchanged; +only the *correlated* shape needs the fallback. 
+ +The expander reuses :class:`giql.generators.base.BaseGIQLGenerator`'s +``_generate_distance_case`` (shared with DISTANCE, #140) and ``_nearest_*`` +passthrough/diagnostic helpers — all static, so they are called on the class with +no generator instance — then parses the assembled SQL fragments into AST so the +emitted SQL is reserialized by the active target's serializer. +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.expander import EXPAND_ALIAS_PREFIX +from giql.expander import ExpansionContext +from giql.expander import register +from giql.expressions import GIQLNearest +from giql.generators.base import BaseGIQLGenerator +from giql.resolver import ResolvedInterval +from giql.resolver import ResolvedRef +from giql.targets import GenericTarget + +#: Reserved column names the window-function fallback synthesizes inside its +#: ranked subquery. They are derived from the expander's reserved +#: ``EXPAND_ALIAS_PREFIX`` (``__giql_x_``) — rather than hardcoded — so they stay +#: clear of user identifiers and track the prefix if it ever changes, mirroring +#: the other reserved internal prefixes. 
+_RANK_COL = f"{EXPAND_ALIAS_PREFIX}rn" +_REF_KEY_PREFIX = f"{EXPAND_ALIAS_PREFIX}rk_" + + +def _nearest_params( + expression: GIQLNearest, +) -> tuple[int, int | None, bool, bool]: + """Unpack the (k, max_distance, stranded, signed) parameters of a NEAREST.""" + k = expression.args.get("k") + k_value = int(str(k)) if k else 1 + + max_distance = expression.args.get("max_distance") + max_dist_value = int(str(max_distance)) if max_distance else None + + is_stranded = BaseGIQLGenerator._extract_bool_param(expression.args.get("stranded")) + is_signed = BaseGIQLGenerator._extract_bool_param(expression.args.get("signed")) + return k_value, max_dist_value, is_stranded, is_signed + + +def _distance_and_filters( + expression: GIQLNearest, + table_name: str, + target_ref: ResolvedRef, + ref: ResolvedInterval, + ref_fragments: tuple[str, str, str, str | None] | None = None, +) -> tuple[str, str, list[str], str]: + """Build the shared distance SQL, the WHERE clauses, and the passthrough. + + Returns ``(distance_expr, abs_distance_expr, where_clauses, passthrough)`` — + the fragments common to the LATERAL/standalone form and the decorrelated + fallback. Distance math, the chromosome pre-filter, the optional strand match, + and the optional ``max_distance`` filter all reproduce the legacy + ``giqlnearest_sql`` emitter exactly. Each form derives its deterministic + ORDER BY tiebreaker from the target columns itself. + + ``ref_fragments`` optionally overrides the reference ``(chrom, start, end, + strand)`` SQL fragments. The LATERAL form consumes the resolution's + outer-qualified fragments verbatim; the fallback passes fragments pointing at + its renamed, pre-projected reference relation so the cross-joined columns + carry names distinct from the target's (DataFusion's planner cannot resolve a + window ordering over a join with duplicate column names). 
+ """ + target_chrom, target_start, target_end = target_ref.cols + _k_value, max_dist_value, is_stranded, is_signed = _nearest_params(expression) + + output_table = BaseGIQLGenerator._nearest_output_encoding(expression, target_ref) + passthrough = BaseGIQLGenerator._nearest_passthrough( + table_name, target_start, target_end, output_table + ) + + if ref_fragments is not None: + ref_chrom, ref_start, ref_end, ref_strand_frag = ref_fragments + else: + ref_chrom, ref_start, ref_end, ref_strand_frag = ( + ref.chrom, + ref.start, + ref.end, + ref.strand, + ) + + ref_strand = None + target_strand = None + if is_stranded: + ref_strand = ref_strand_frag + if output_table and output_table.strand_col: + target_strand = f'{table_name}."{output_table.strand_col}"' + + target_chrom_expr = f'{table_name}."{target_chrom}"' + target_start_expr = f'{table_name}."{target_start}"' + target_end_expr = f'{table_name}."{target_end}"' + + distance_expr = BaseGIQLGenerator._generate_distance_case( + ref_chrom, + ref_start, + ref_end, + ref_strand, + target_chrom_expr, + target_start_expr, + target_end_expr, + target_strand, + stranded=is_stranded, + signed=is_signed, + ) + abs_distance_expr = f"ABS({distance_expr})" + + where_clauses = [f"{ref_chrom} = {target_chrom_expr}"] + if is_stranded and ref_strand and target_strand: + where_clauses.append(f"{ref_strand} = {target_strand}") + if max_dist_value is not None: + where_clauses.append(f"({abs_distance_expr}) <= {max_dist_value}") + + return distance_expr, abs_distance_expr, where_clauses, passthrough + + +def _lateral_form( + expression: GIQLNearest, + ctx: ExpansionContext, + table_name: str, + target_ref: ResolvedRef, + ref: ResolvedInterval, +) -> exp.Expression: + """The portable LATERAL/standalone subquery. + + Builds a two-level subquery: an inner ``SELECT {passthrough}, {distance_expr} AS + distance FROM {table_name} WHERE ...`` that materializes the distance, wrapped by + an outer ``SELECT * FROM ({inner}) AS x ORDER BY ABS(x.distance), x.{start}, + x.{end} 
LIMIT k`` that orders on the *precomputed* ``distance`` column. For a + correlated placement the parent ``LATERAL`` correlates it to the outer row; + for a standalone (literal-reference) placement it stands alone. + + Splitting the distance computation (inner) from the ordering (outer) is + load-bearing for cross-engine support: + + * DuckDB's correlated-``LATERAL`` binder will not resolve a SELECT-list alias + named ``distance`` from inside an ``ORDER BY`` that also projects + ``{table_name}.*``, so the order key must reference a *materialized* column + (``x.distance``) from the wrapping level rather than an alias in the same + SELECT. + * DataFusion's planner, given the distance ``CASE`` re-emitted inline in the + ``ORDER BY`` over the chromosome-equality prefiltered scan, rewrites the + filtered ``chrom`` to a self-comparison in one copy of the CASE but not the + other and trips ``SanityCheckPlan``; ordering on the materialized column + avoids re-deriving the key. + + A ``(start, end)`` tiebreaker follows ``ABS(distance)`` so rows tied at the + k-th distance order deterministically — and identically across engines and + against the decorrelated fallback's ranking — *when ``(start, end)`` + distinguishes the tied candidates*. Two target rows sharing both distance and + ``(start, end)`` remain order-ambiguous (no key here breaks that residual + tie); the two forms are set-equivalent up to such coordinate-duplicate ties + (#142 A5). + """ + k_value, *_ = _nearest_params(expression) + ( + distance_expr, + _abs_distance_expr, + where_clauses, + passthrough, + ) = _distance_and_filters(expression, table_name, target_ref, ref) + where_sql = " AND ".join(where_clauses) + # The wrapping level reads the inner row's *bare* column names (the passthrough + # projected ``{table_name}.*``), so the tiebreaker qualifies them by the wrapper + # alias, not the original ``table_name."col"``. 
+ _chrom, target_start_col, target_end_col = target_ref.cols + wrapper = ctx.alias() + inner = ( + f"SELECT {passthrough}, {distance_expr} AS distance " + f"FROM {table_name} WHERE {where_sql}" + ) + sql = ( + f"(SELECT * FROM ({inner}) AS {wrapper} " + f'ORDER BY ABS({wrapper}."distance"), ' + f'{wrapper}."{target_start_col}", {wrapper}."{target_end_col}" ' + f"LIMIT {k_value})" + ) + return parse_one(sql, dialect=GIQLDialect) + + +def _outer_relation(ref: ResolvedInterval) -> tuple[str, str]: + """Return ``(physical_relation, alias)`` for the correlated reference table. + + The reference endpoints are alias-qualified fragments (``a."chrom"``). The + alias is the outer table's correlation name in the query; the physical + relation comes from the reference's backing :class:`~giql.table.Table`. Both + are needed to re-introduce the outer relation inside the decorrelated + subquery the fallback builds. + + The ``else alias`` branch (``ref.table is None``) only fires for a reference + whose backing table the resolver could not attach. For a *correlated* + NEAREST — the only shape that reaches the fallback — the reference is an + outer-row column, so the resolver always attaches its table and this branch + is not reached in practice. Falling back to ``alias`` (yielding a + cosmetically redundant ``FROM {relation} AS {alias}``, where ``relation == + alias``) keeps the emitted SQL valid rather than emitting an empty relation + name should that invariant ever not hold; the assert pins the expectation. + """ + parsed = parse_one(ref.chrom, dialect=GIQLDialect) + alias = parsed.table if isinstance(parsed, exp.Column) else "" + if ref.table is not None: + relation = ref.table.name + else: + # Defensive fallback only: a correlated reference always carries a + # resolved backing table, so relation == alias here would be a redundant + # self-alias rather than a real two-name relation. 
+ assert alias, ( + "correlated NEAREST fallback expected the reference to carry either a " + "backing table or an alias-qualified column" + ) + relation = alias + return relation, alias + + +def _fallback_form( + expression: GIQLNearest, + ctx: ExpansionContext, + table_name: str, + target_ref: ResolvedRef, + ref: ResolvedInterval, +) -> exp.Expression: + """The decorrelated window-function fallback for non-LATERAL targets. + + Rewrites the surrounding ``{outer} AS a CROSS JOIN LATERAL (nearest) AS b`` + into ``{outer} AS a JOIN ({ranked}) AS b ON {key match} AND + b.{rank} <= k``. The ranked subquery cross-joins the target against the outer + relation and ranks candidates per distinct reference key with + ``ROW_NUMBER()``; the join re-associates the top-k back to every outer row + sharing that key, reproducing the per-row LATERAL semantics. Swaps the parent + ``LATERAL`` for the decorrelated subquery in place and returns the (now + detached) NEAREST node, so the pass's own ``node.replace`` is a no-op. + + The no-op return relies on NEAREST having no nestable inner GIQL operator: a + detached node carrying a still-pending descendant would strand that + descendant's later ``node.replace``. NEAREST's only operands are a registered + target table and an interval reference, neither of which is an expandable + operator, so nothing pending hangs off the node this detaches. + """ + lateral = expression.parent + # Internal invariants the surrounding-AST rewrite depends on. The fallback + # only runs for a correlated NEAREST, whose pass-1 placement is always a CROSS + # JOIN LATERAL under a JOIN; a violation is an internal pipeline bug, not user + # error, so fail loudly with a clear message rather than dereferencing None. + # (The LATERAL *alias* is NOT an internal invariant — it is optional user + # input — so it is synthesized below rather than asserted; see B3.) 
+ assert isinstance(lateral, exp.Lateral), ( + "correlated NEAREST fallback expected its parent to be a LATERAL, got " + f"{type(lateral).__name__}" + ) + join = lateral.parent + assert isinstance(join, exp.Join), ( + "correlated NEAREST fallback expected the LATERAL to sit under a JOIN, got " + f"{type(join).__name__}" + ) + # The decorrelated join references the LATERAL's alias on both sides (its ON + # clause and the replacement subquery's name). A correlated NEAREST written + # *without* a LATERAL alias is legitimate user input that transpiles fine on + # lateral-capable engines, so the fallback must not require one: synthesize a + # collision-safe alias from the run's sequence when the LATERAL carries none. + # This closes the DuckDB/DataFusion behavior gap (and, unlike an ``assert``, + # survives ``python -O``, which would otherwise strip the guard and leave a + # ``NoneType`` deref). The synthesized name is internal to the rewritten join. + lateral_alias = lateral.args.get("alias") + if lateral_alias is None or not lateral_alias.name: + alias = ctx.alias() + alias_node = exp.TableAlias(this=exp.to_identifier(alias)) + else: + alias = lateral_alias.name + alias_node = lateral_alias.copy() + + relation, outer_alias = _outer_relation(ref) + k_value, _max, is_stranded, _signed = _nearest_params(expression) + # Bare target column names: the candidate subquery exposes the target row via + # ``target.*``, so its tiebreaker columns are referenced by name, not by the + # ``table_name."col"`` qualifier the distance math uses. + _target_chrom, target_start_col, target_end_col = target_ref.cols + + # Pre-project the outer relation's reference columns under fresh, non-target + # names into a renamed derived relation. Cross-joining *this* (rather than the + # raw outer table) keeps every reference column distinct from the target's + # columns: DataFusion's planner cannot resolve a window ordering over a join + # whose two sides share column names (e.g. 
both expose ``start`` / ``end``). + # + # The reference key identifies one distinct reference interval, which the + # ranking partitions by and the join re-associates on. Position + # (chrom/start/end) alone identifies it in the unstranded case; in stranded + # mode strand joins the key too, because two outer rows at the same position + # but opposite strands must each get their own strand-filtered nearest. The + # ref relation is de-duplicated on the key with DISTINCT so ranking happens + # once per distinct reference and the join fans the top-k back out to every + # outer row sharing it — exactly the per-row LATERAL semantics, even when the + # outer table holds duplicate reference rows. + # Mint the synthetic relation aliases from the run's collision-safe sequence + # (rather than hardcoding ``__giql_x_ref`` / ``__giql_x_cand``) so two NEAREST + # operators in one query never reuse the same derived-relation name. The + # reserved *column* names below stay derived from EXPAND_ALIAS_PREFIX. + ref_relation_alias = ctx.alias() + candidate = ctx.alias() + strand_name = f"{_REF_KEY_PREFIX}strand" + stranded_key = is_stranded and ref.strand is not None + + key_names = [f"{_REF_KEY_PREFIX}chrom", f"{_REF_KEY_PREFIX}start", + f"{_REF_KEY_PREFIX}end"] + source_frags = [ref.chrom, ref.start, ref.end] + if stranded_key: + key_names.append(strand_name) + source_frags.append(ref.strand) + + ref_projection = ", ".join( + f'{frag} AS "{name}"' for name, frag in zip(key_names, source_frags) + ) + ref_relation = ( + f"(SELECT DISTINCT {ref_projection} FROM {relation} AS {outer_alias})" + f" AS {ref_relation_alias}" + ) + + # Reference fragments now point at the renamed relation's safe columns. 
+ renamed = [f'{ref_relation_alias}."{name}"' for name in key_names] + renamed_strand = ( + f'{ref_relation_alias}."{strand_name}"' if stranded_key else None + ) + ref_fragments = (renamed[0], renamed[1], renamed[2], renamed_strand) + + ( + distance_expr, + _abs_distance_expr, + where_clauses, + passthrough, + ) = _distance_and_filters( + expression, table_name, target_ref, ref, ref_fragments=ref_fragments + ) + + # Surface the reference-key columns so the rewritten join can match each + # ranked candidate back to its outer row(s). Ranking depends only on these + # values, so partitioning by them and re-joining is identical to the per-row + # LATERAL form even when outer rows share a reference value. + key_cols = list(zip(key_names, renamed)) + key_projection = ", ".join(f'{frag} AS "{name}"' for name, frag in key_cols) + where_sql = " AND ".join(where_clauses) + + # Compute the candidate set (cross join + distance + reference keys) in an + # inner subquery, then add ROW_NUMBER() in the enclosing one. Keeping the + # join and the window in *separate* query levels is load-bearing on + # DataFusion: fused into one level its optimizer mis-derives the window's sort + # order from the chromosome-equality prefilter and trips ``SanityCheckPlan``. + inner = ( + f"SELECT {passthrough}, {distance_expr} AS distance, {key_projection} " + f"FROM {table_name} CROSS JOIN {ref_relation} " + f"WHERE {where_sql}" + ) + partition = ", ".join(f'{candidate}."{name}"' for name, _ in key_cols) + # A ``(start, end)`` tiebreaker follows ``ABS(distance)`` so rows tied at the + # k-th distance rank identically here and in the LATERAL form when + # ``(start, end)`` distinguishes them, making the two set-equivalent up to ties + # among candidates sharing both distance and coordinates (no engine-dependent + # tie ordering otherwise). 
+ ranked = ( + f"(SELECT {candidate}.*, " + f"ROW_NUMBER() OVER (PARTITION BY {partition} " + f"ORDER BY ABS({candidate}.distance), " + f'{candidate}."{target_start_col}", {candidate}."{target_end_col}") ' + f'AS "{_RANK_COL}" ' + f"FROM ({inner}) AS {candidate})" + ) + ranked_subquery = parse_one(ranked, dialect=GIQLDialect) + + # Match each ranked candidate back to the *outer* relation by its reference + # value (the original outer-qualified fragments, e.g. ``a."chrom"``), not the + # renamed inner columns which exist only inside the subquery. + on_parts = [ + f'{alias}."{name}" = {src}' for name, src in zip(key_names, source_frags) + ] + on_parts.append(f'{alias}."{_RANK_COL}" <= {k_value}') + on_sql = " AND ".join(on_parts) + on_expr = parse_one(on_sql, dialect=GIQLDialect) + + subquery = exp.Subquery(this=ranked_subquery.this, alias=alias_node) + + # Convert ``{outer} CROSS JOIN LATERAL (nearest) AS b`` into + # ``{outer} JOIN (ranked) AS b ON {key match} AND b.rn <= k``. Swap the + # whole LATERAL out for the decorrelated subquery and drop the CROSS kind so + # the ON clause attaches as a plain (inner) join. + lateral.replace(subquery) + join.set("kind", None) + join.set("side", None) + join.set("on", on_expr) + + # The LATERAL (and the NEAREST node within it) is now detached; returning the + # node unchanged makes the pass's ``node.replace`` a no-op. + return expression + + +def expand_nearest(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a NEAREST node to LATERAL or the decorrelated window-function form. + + Selects on ``ctx.capabilities.supports_lateral`` and whether the node is + correlated (its parent is a ``LATERAL``). Lateral-capable targets and every + standalone (literal-reference) placement get the portable LATERAL/standalone + subquery; a correlated NEAREST on a target without LATERAL support gets the + decorrelated window-function fallback. 
+ """ + assert isinstance(node, GIQLNearest) + resolution = ctx.resolution + + target_ref = resolution.slot("this") if resolution is not None else None + if not isinstance(target_ref, ResolvedRef): + # An unresolved target means it is not a registered table; raise the + # historical diagnostic (verbatim from the removed giqlnearest_sql). + target = node.this + if isinstance(target, exp.Table): + target_name = target.name + elif isinstance(target, exp.Column): + target_name = target.table if target.table else str(target.this) + else: + target_name = str(target) + raise ValueError( + f"Target table '{target_name}' not found in tables. " + "Register the table before transpiling." + ) + table_name = target_ref.name + + ref = resolution.slot("reference") + if not isinstance(ref, ResolvedInterval): + mode = BaseGIQLGenerator._detect_nearest_mode(node) + BaseGIQLGenerator._raise_nearest_reference_error(node, mode, resolution) + + # A literal-range reference is uncorrelated even under CROSS JOIN LATERAL: its + # endpoints are constants, not outer-row columns, so the subquery stands alone + # and every target — DataFusion included — takes the LATERAL/standalone form. + # Only a genuinely correlated reference (a column / implicit-outer endpoint) + # needs the decorrelated window-function fallback on a lateral-incapable + # target. Gating on parentage alone would mis-route a literal range into + # ``_fallback_form``, which dereferences a non-existent outer relation. + correlated = isinstance(node.parent, exp.Lateral) and ref.kind != "literal_range" + if correlated and not ctx.capabilities.supports_lateral: + return _fallback_form(node, ctx, table_name, target_ref, ref) + return _lateral_form(node, ctx, table_name, target_ref, ref) + + +# The generic registration covers every target through the registry's fallback +# chain; the expander branches on ctx.capabilities.supports_lateral internally, +# so no per-target override is needed. 
+register(GenericTarget, GIQLNearest)(expand_nearest) diff --git a/src/giql/expressions.py b/src/giql/expressions.py index 4949fdd..2e8fd0c 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -384,7 +384,11 @@ class GIQLNearest(exp.Func): #: half-open) operands are left untouched and the emitted SQL stays #: byte-identical. GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators pass (epic #137, issue #142): NEAREST is + #: expanded by ``giql.expanders.nearest`` — the portable correlated LATERAL + #: subquery where ``supports_lateral`` holds, a decorrelated window-function + #: form otherwise. The legacy ``giqlnearest_sql`` emitter has been removed. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"registered_table"}), required=True), diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index cc494dc..599eb0d 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -3,13 +3,13 @@ from giql.canonical import decanonical_end from giql.canonical import decanonical_start +from giql.dialect import GIQLDialect from giql.expressions import GIQLDisjoin from giql.expressions import GIQLNearest from giql.range_parser import RangeParser from giql.resolver import META_KEY from giql.resolver import OperatorResolution from giql.resolver import ResolvedColumn -from giql.resolver import ResolvedInterval from giql.resolver import ResolvedRef from giql.table import Table from giql.table import Tables @@ -28,10 +28,6 @@ class BaseGIQLGenerator(Generator): compatibility with virtually all SQL databases. 
""" - # Most databases support LATERAL joins (PostgreSQL 9.3+, DuckDB 0.7.0+) - # SQLite does not support LATERAL, so it overrides this to False - SUPPORTS_LATERAL = True - def __init__(self, tables: Tables | None = None, **kwargs): super().__init__(**kwargs) self.tables = tables or Tables() @@ -43,172 +39,9 @@ def __init__(self, tables: Tables | None = None, **kwargs): # ``spatialsetpredicate_sql`` emitters or their ``_generate_spatial_*`` / # ``_predicate_operand`` helpers. - def giqlnearest_sql(self, expression: GIQLNearest) -> str: - """Generate SQL for NEAREST function. - - Detects mode (standalone vs correlated) and generates appropriate SQL: - - Standalone: Direct query with ORDER BY + LIMIT - - Correlated (LATERAL): Subquery for k-nearest neighbors - - :param expression: - GIQLNearest expression node - :return: - SQL string for NEAREST operation - """ - # Detect mode - mode = self._detect_nearest_mode(expression) - - # Unpack the resolution metadata attached by ResolveOperatorRefs (pass 1). - resolution = self._nearest_resolution(expression) - - # Target (already a registered-table ResolvedRef from the pass). An - # unresolved target means it is not a registered table; raise the - # historical diagnostic. - target_ref = resolution.slot("this") if resolution is not None else None - if not isinstance(target_ref, ResolvedRef): - target = expression.this - if isinstance(target, exp.Table): - target_name = target.name - elif isinstance(target, exp.Column): - target_name = target.table if target.table else str(target.this) - else: - target_name = str(target) - raise ValueError( - f"Target table '{target_name}' not found in tables. " - "Register the table before transpiling." - ) - table_name = target_ref.name - target_chrom, target_start, target_end = target_ref.cols - - # The target's *declared* encoding, which the passed-through target row - # (SELECT {table_name}.*) must round-trip back into. 
CanonicalizeCoordinates - # (pass 2) preserves it on the resolution when it wraps a non-canonical - # target in a __giql_canon_* CTE (the slot's own Table is then None); a - # canonical target is left unwrapped and its slot Table carries the - # (identity) encoding. The synthesized `distance` column is encoding- - # invariant (a count of bases) and needs no round-trip. - output_table = self._nearest_output_encoding(expression, target_ref) - passthrough = self._nearest_passthrough( - table_name, target_start, target_end, output_table - ) - - # Reference interval (a ResolvedInterval from the pass). An unresolved - # reference re-raises the generator's historical diagnostic. Input - # canonicalization is owned by CanonicalizeCoordinates (pass 2, issue - # #123): a literal range is already canonical, and a column / implicit- - # outer reference's endpoints are canonicalized in place by the pass, so - # the emitter consumes the fragments verbatim with no canonicalization. - ref = resolution.slot("reference") - if not isinstance(ref, ResolvedInterval): - self._raise_nearest_reference_error(expression, mode, resolution) - ref_chrom, ref_start, ref_end = ref.chrom, ref.start, ref.end - - # Extract parameters - k = expression.args.get("k") - k_value = int(str(k)) if k else 1 # Default k=1 - - max_distance = expression.args.get("max_distance") - max_dist_value = int(str(max_distance)) if max_distance else None - - is_stranded = self._extract_bool_param(expression.args.get("stranded")) - is_signed = self._extract_bool_param(expression.args.get("signed")) - - # Resolve strand columns if stranded mode. The reference strand is - # carried on the resolved interval (a literal's strand, an explicit - # column's strand, or the outer table's strand for an implicit - # reference — already gated to preserve the historical divergence). 
- ref_strand = None - target_strand = None - if is_stranded: - ref_strand = ref.strand - # When pass 2 wraps a non-canonical target its slot Table is blanked, - # so the strand column name comes from the *declared* encoding the - # pass preserved (output_table). The canon CTE's SELECT * REPLACE - # passes the strand column through unchanged under its physical name, - # so the qualifier stays the relation NEAREST selects from. - if output_table and output_table.strand_col: - target_strand = f'{table_name}."{output_table.strand_col}"' - - # Distance math below assumes 0-based half-open. Input canonicalization - # is owned by CanonicalizeCoordinates (pass 2, issue #123): a - # non-canonical target is rewritten to a canonical __giql_canon_* CTE - # before generation (table_name then names the CTE), so the target - # endpoints are consumed verbatim with no in-emitter canonicalization. The - # output round-trip of the passed-through target row stays here (see the - # SELECT projection below). 
- target_start_expr = f'{table_name}."{target_start}"' - target_end_expr = f'{table_name}."{target_end}"' - - # Build distance calculation using CASE expression - # For NEAREST: ORDER BY absolute distance, but RETURN signed distance - distance_expr = self._generate_distance_case( - ref_chrom, - ref_start, - ref_end, - ref_strand, - f'{table_name}."{target_chrom}"', - target_start_expr, - target_end_expr, - target_strand, - stranded=is_stranded, - signed=is_signed, - ) - - # Use absolute distance for ordering and filtering - abs_distance_expr = f"ABS({distance_expr})" - - # Build WHERE clauses - where_clauses = [ - f'{ref_chrom} = {table_name}."{target_chrom}"' # Chromosome pre-filter - ] - - # Add strand matching for stranded mode - if is_stranded and ref_strand and target_strand: - where_clauses.append(f"{ref_strand} = {target_strand}") - - if max_dist_value is not None: - where_clauses.append(f"({abs_distance_expr}) <= {max_dist_value}") - - where_sql = " AND ".join(where_clauses) - - # Generate SQL based on mode - if mode == "standalone": - # Standalone mode: direct ORDER BY + LIMIT - # Return signed distance, but order by absolute distance - sql = f"""( - SELECT {passthrough}, {distance_expr} AS distance - FROM {table_name} - WHERE {where_sql} - ORDER BY {abs_distance_expr} - LIMIT {k_value} - )""" - else: - # Correlated mode: requires LATERAL join support - if not self.SUPPORTS_LATERAL: - raise ValueError( - "NEAREST in correlated mode (CROSS JOIN LATERAL) is not supported " - "in SQLite. SQLite does not support LATERAL joins. " - "\n\nAlternatives:" - "\n1. Use standalone mode: SELECT * FROM NEAREST(table, " - "reference='chr1:100-200', k=3)" - "\n2. Use DuckDB for queries requiring LATERAL joins" - "\n3. 
Manually write equivalent window function query" - ) - - # LATERAL mode: subquery for k-nearest neighbors - # Return signed distance, but order by absolute distance - sql = f"""( - SELECT {passthrough}, {distance_expr} AS distance - FROM {table_name} - WHERE {where_sql} - ORDER BY {abs_distance_expr} - LIMIT {k_value} - )""" - - return sql.strip() - + @staticmethod def _nearest_output_encoding( - self, expression: GIQLNearest, target_ref: ResolvedRef + expression: GIQLNearest, target_ref: ResolvedRef ) -> Table | None: """Return the target's declared encoding for NEAREST's row passthrough. @@ -233,8 +66,8 @@ def _nearest_output_encoding( return preserved return target_ref.table + @staticmethod def _nearest_passthrough( - self, table_name: str, target_start: str, target_end: str, @@ -454,8 +287,8 @@ def _disjoin_passthrough( f't.* REPLACE ({pt_start} AS "{target_start}", {pt_end} AS "{target_end}")' ) + @staticmethod def _generate_distance_case( - self, chrom_a: str, start_a: str, end_a: str, @@ -563,8 +396,9 @@ def _generate_distance_case( f"ELSE ({start_a} - {end_b} + 1) END END" ) + @staticmethod def _detect_nearest_mode( - self, expression: GIQLNearest, parent_expression: exp.Expression | None = None + expression: GIQLNearest, parent_expression: exp.Expression | None = None ) -> str: """Detect whether NEAREST is in standalone or correlated mode. @@ -588,25 +422,8 @@ def _detect_nearest_mode( # (validation will catch missing reference errors later) return "correlated" - def _nearest_resolution(self, expression: GIQLNearest) -> OperatorResolution | None: - """Return the NEAREST resolution attached by ResolveOperatorRefs (pass 1). - - The transpile pipeline attaches an - :class:`~giql.resolver.OperatorResolution` before generation, and it - survives the generator's defensive tree copy. The emitter reads only the - attached metadata; resolution lives entirely in the pass. 
- - :param expression: - GIQLNearest expression node - :return: - The attached :class:`~giql.resolver.OperatorResolution`, or ``None`` - if resolution did not produce one. - """ - resolution = expression.meta.get(META_KEY) - return resolution if isinstance(resolution, OperatorResolution) else None - + @staticmethod def _raise_nearest_reference_error( - self, expression: GIQLNearest, mode: str, resolution: OperatorResolution | None, @@ -654,7 +471,7 @@ def _raise_nearest_reference_error( # An explicit reference that deferred is a literal range that failed to # parse (column references always resolve). Re-parse to surface the # original parse error in the historical message. - reference_sql = self.sql(reference) + reference_sql = reference.sql(dialect=GIQLDialect) range_str = reference_sql.strip("'\"") try: RangeParser.parse(range_str).to_zero_based_half_open() diff --git a/src/giql/targets.py b/src/giql/targets.py index 6fd29d3..825a6d2 100644 --- a/src/giql/targets.py +++ b/src/giql/targets.py @@ -29,11 +29,13 @@ class Capabilities: Parameters ---------- supports_lateral : bool - Whether the engine supports ``LATERAL`` / correlated joins. Will - drive the NEAREST LATERAL-vs-window-function strategy (#142). Until - then, :attr:`giql.generators.base.BaseGIQLGenerator.SUPPORTS_LATERAL` - remains the live source of truth at generation time; #142 reconciles - the two. + Whether the engine supports ``LATERAL`` / correlated joins. Drives the + NEAREST LATERAL-vs-window-function strategy (#142): a correlated NEAREST + expands to a portable correlated ``LATERAL`` subquery where this holds + and to a decorrelated window-function form where it does not. This + capability is the single source of truth — the former + ``BaseGIQLGenerator.SUPPORTS_LATERAL`` generator attribute has been + removed. supports_star_replace : bool Whether the engine supports ``SELECT * REPLACE (...)``. 
Drives the coordinate-canonicalization output: ``* REPLACE`` where supported, diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index 9c8f8b0..3ac63c0 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -8,7 +8,6 @@ from hypothesis import given from hypothesis import settings from hypothesis import strategies as st -from sqlglot import exp from sqlglot import parse_one import giql.expanders # noqa: F401 (registers built-in expanders) @@ -17,7 +16,6 @@ from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect from giql.expander import ExpandOperators -from giql.expressions import GIQLNearest from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables @@ -28,14 +26,14 @@ def _generate_through_passes(sql: str, tables: Tables) -> str: """Parse, run normalization passes 1-3, then generate SQL. Coordinate canonicalization for operator operands moved out of the emitter and - into the CanonicalizeCoordinates pass (issue #123), and DISTANCE (issue #140) - and the spatial / set predicates (issue #141) moved onto the registry's + into the CanonicalizeCoordinates pass (issue #123), and DISTANCE (#140), the + spatial / set predicates (#141), and NEAREST (#142) moved onto the registry's ExpandOperators pass (epic #137). Emitter-level tests that pin canonicalized / expanded output must therefore run all three passes before generating, exactly as :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a - bare parsed AST. Operators still on the legacy emitter (NEAREST, DISJOIN) pass - through the expansion pass untouched. This helper is used where the full - ``transpile`` pipeline would otherwise rewrite the node away (a column-to-column + bare parsed AST. Operators still on the legacy emitter (DISJOIN) pass through + the expansion pass untouched. 
This helper is used where the full ``transpile`` + pipeline would otherwise rewrite the node away (a column-to-column ``INTERSECTS`` is turned into a binned equi-join before the predicate expander runs). """ @@ -110,13 +108,12 @@ def test_instantiation_defaults(self): """ GIVEN no tables provided WHEN Generator is instantiated with defaults - THEN Generator has empty Tables and SUPPORTS_LATERAL is True. + THEN Generator has empty Tables. """ generator = BaseGIQLGenerator() assert generator.tables is not None assert "variants" not in generator.tables - assert generator.SUPPORTS_LATERAL is True def test_instantiation_with_tables(self, tables_info): """ @@ -402,41 +399,43 @@ def test_spatialsetpredicate_sql_all(self): ) assert output == expected - def test_giqlnearest_sql_standalone(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_ordered_limit_subquery_when_standalone( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest in standalone mode with literal reference - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Subquery with ORDER BY distance LIMIT k is generated. """ sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # NEAREST now expands via its registered expander (#142): a two-level + # subquery — an inner SELECT that materializes ``distance`` and an outer + # wrapper that orders on the materialized ``__giql_x_0."distance"`` plus a + # deterministic ``(start, end)`` tiebreaker (#142 A5). Splitting the + # distance computation from the ordering keeps DuckDB's correlated-LATERAL + # binder and DataFusion's planner both happy while staying result- + # equivalent to the legacy single-level emitter. 
expected = ( - "SELECT * FROM (\n" - " SELECT genes.*, " - "CASE WHEN 'chr1' != genes.\"chrom\" THEN NULL " + "SELECT * FROM (SELECT * FROM (SELECT genes.*, " + "CASE WHEN 'chr1' <> genes.\"chrom\" THEN NULL " 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" ' - 'THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END AS distance\n' - " FROM genes\n" - " WHERE 'chr1' = genes.\"chrom\"\n" - " ORDER BY ABS(" - "CASE WHEN 'chr1' != genes.\"chrom\" THEN NULL " - 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" ' - 'THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'WHEN 2000 <= genes."start" THEN (genes."start" - 2000 + 1) ' + 'ELSE (1000 - genes."end" + 1) END AS distance ' + "FROM genes WHERE 'chr1' = genes.\"chrom\") AS __giql_x_0 " + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_lateral_subquery_when_correlated( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest in correlated mode (LATERAL join context) - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs on a lateral-capable target THEN LATERAL-compatible subquery is generated. """ sql = ( @@ -446,33 +445,30 @@ def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery: the inner + # SELECT materializes ``distance`` (CASE and WHERE semantically unchanged) + # and the outer wrapper orders on the materialized + # ``__giql_x_0."distance"`` plus a deterministic ``(start, end)`` + # tiebreaker (#142 A5). 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_filter_on_max_distance( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with max_distance parameter - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN WHERE clause includes distance filter. 
""" sql = ( @@ -483,40 +479,35 @@ def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery; the + # max_distance filter on ABS of the distance CASE stays in the inner + # SELECT's WHERE and the outer wrapper orders on the materialized + # ``__giql_x_0."distance"`` plus the deterministic ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom" ' - "AND (ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'AND (ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE 
(peaks."start" - genes."end" + 1) END)\n' - " LIMIT 5\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000) ' + 'AS __giql_x_0 ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 5)' ) assert output == expected - def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_match_strand_when_stranded( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with stranded := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand matching is included in WHERE clause. """ sql = ( @@ -527,45 +518,33 @@ def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery; the stranded + # distance CASE and the ``peaks.strand = genes.strand`` match in the inner + # WHERE are semantically unchanged, with the outer wrapper ordering on the + # materialized ``__giql_x_0."distance"`` plus the ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' 
THEN NULL " - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' "THEN CASE WHEN peaks.\"strand\" = '-' " 'THEN -(genes."start" - peaks."end" + 1) ' 'ELSE (genes."start" - peaks."end" + 1) END ' "ELSE CASE WHEN peaks.\"strand\" = '-' " 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom" ' - 'AND peaks."strand" = genes."strand"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' - "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " - "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' THEN NULL " - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - "THEN CASE WHEN peaks.\"strand\" = '-' " - 'THEN -(genes."start" - peaks."end" + 1) ' - 'ELSE (genes."start" - peaks."end" + 1) END ' - "ELSE CASE WHEN peaks.\"strand\" = '-' " - 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END)\n' - " LIMIT 3\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'AND peaks."strand" = genes."strand") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_implicit_outer_without_strand_column(self): + def test_expand_nearest_should_skip_strand_when_outer_has_no_strand_column(self): """ GIVEN a stranded NEAREST whose implicit-outer table declares no strand column @@ -589,10 +568,12 @@ def test_giqlnearest_sql_implicit_outer_without_strand_column(self): assert "strand" not in output assert 
'nostr."chrom" = genes."chrom"' in output - def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_signed_distance_when_signed( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with signed := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Distance expression includes signed calculation. """ sql = ( @@ -603,57 +584,37 @@ def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery; the signed + # distance CASE (negated ELSE branch for upstream) is semantically + # unchanged in the inner SELECT, with the outer wrapper ordering on the + # materialized ``__giql_x_0."distance"`` plus the ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'ELSE -(peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom") AS __giql_x_0 ' + 'ORDER BY 
ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_no_lateral_support(self, tables_with_peaks_and_genes): - """ - GIVEN a GIQLNearest on a generator with SUPPORTS_LATERAL=False - WHEN giqlnearest_sql is called in correlated mode - THEN ValueError is raised with helpful message. - """ - - # Create a generator subclass without LATERAL support - class NoLateralGenerator(BaseGIQLGenerator): - SUPPORTS_LATERAL = False - - # Use query without explicit reference to trigger correlated mode - sql = "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3)" - ast = parse_one(sql, dialect=GIQLDialect) - ast = resolve_operator_refs(ast, tables_with_peaks_and_genes) - ast = canonicalize_coordinates(ast) - - generator = NoLateralGenerator(tables=tables_with_peaks_and_genes) - - with pytest.raises(ValueError, match="LATERAL"): - generator.generate(ast) + # The legacy ``SUPPORTS_LATERAL=False`` generator-level error path was removed + # with ``giqlnearest_sql`` (#142): lateral support is now a target capability, + # and a target without it (DataFusion) gets the decorrelated window-function + # fallback rather than a hard error. That fallback's result-identity with the + # LATERAL form is verified by the cross-target oracle + # (tests/integration/datafusion/test_cross_target_oracle.py). 
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) @given( k=st.integers(min_value=1, max_value=100), max_distance=st.integers(min_value=1, max_value=10_000_000), ) - def test_giqlnearest_sql_parameter_handling_property( + def test_expand_nearest_should_carry_k_and_max_distance_property( self, tables_with_peaks_and_genes, k, max_distance ): """ @@ -866,12 +827,12 @@ def test_select_sql_join_without_alias(self, tables_with_two_tables): ) assert output == expected - def test_giqlnearest_sql_stranded_literal_with_strand( + def test_expand_nearest_should_use_literal_strand_when_stranded( self, tables_with_peaks_and_genes ): """ GIVEN a GIQLNearest with stranded := true and literal reference containing strand - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand from literal range is parsed and used in filtering. """ sql = ( @@ -885,12 +846,12 @@ def test_giqlnearest_sql_stranded_literal_with_strand( assert "'+'" in output assert 'genes."strand"' in output - def test_giqlnearest_sql_stranded_implicit_reference( + def test_expand_nearest_should_resolve_outer_strand_when_implicit_reference( self, tables_with_peaks_and_genes ): """ GIVEN a GIQLNearest in correlated mode with implicit reference and stranded := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand column is resolved from outer table and used. 
""" sql = "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3, stranded := true)" @@ -989,29 +950,22 @@ def test_giqldistance_sql_literal_second_arg_error(self, tables_with_two_tables) with pytest.raises(ValueError, match="Literal range as second argument"): expander.transform(ast) - def test_giqlnearest_sql_missing_outer_table_error( + def test_expand_nearest_should_raise_when_outer_table_unresolvable( self, tables_with_peaks_and_genes ): """ - GIVEN a GIQLNearest in correlated mode without reference where outer table - cannot be found - WHEN giqlnearest_sql is called - THEN ValueError is raised with helpful message about specifying reference. + GIVEN a GIQLNearest without a reference and no resolvable outer table + WHEN the NEAREST expander runs + THEN ValueError is raised with a helpful message about specifying reference. """ + # Arrange — no reference and no LATERAL outer relation to infer one from. + sql = "SELECT * FROM NEAREST(genes, k := 3)" - nearest = GIQLNearest( - this=exp.Table(this=exp.Identifier(this="genes")), - k=exp.Literal.number(3), - ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - + # Act & assert with pytest.raises(ValueError, match="Could not find outer table"): - generator.giqlnearest_sql(nearest) + _generate_through_passes(sql, tables_with_peaks_and_genes) - def test_giqlnearest_sql_outer_table_not_in_tables(self): + def test_expand_nearest_should_raise_when_outer_table_unregistered(self): """ GIVEN a NEAREST whose implicit-outer relation is found but not registered WHEN the query is generated @@ -1027,10 +981,12 @@ def test_giqlnearest_sql_outer_table_not_in_tables(self): with pytest.raises(ValueError, match="not found in tables"): _generate_through_passes(sql, tables) - def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_genes): + def 
test_expand_nearest_should_raise_when_reference_range_unparseable( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with invalid/unparseable reference range string - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN ValueError is raised with parse error details. """ sql = "SELECT * FROM NEAREST(genes, reference := 'invalid_range', k := 3)" @@ -1038,36 +994,35 @@ def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_gen with pytest.raises(ValueError, match="Could not parse reference genomic range"): _generate_through_passes(sql, tables_with_peaks_and_genes) - def test_giqlnearest_sql_no_tables_error(self): + def test_expand_nearest_should_raise_when_no_tables_registered(self): """ - GIVEN a GIQLNearest without tables registered - WHEN giqlnearest_sql is called - THEN ValueError is raised because target table cannot be resolved. + GIVEN a GIQLNearest with no tables registered + WHEN the NEAREST expander runs + THEN ValueError is raised because the target table cannot be resolved. """ + # Arrange sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" - ast = parse_one(sql, dialect=GIQLDialect) - - # Generator with empty tables - table won't be found - generator = BaseGIQLGenerator() + # Act & assert with pytest.raises(ValueError, match="not found in tables"): - generator.generate(ast) + _generate_through_passes(sql, Tables()) - def test_giqlnearest_sql_target_not_in_tables(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_raise_when_target_unregistered( + self, tables_with_peaks_and_genes + ): """ - GIVEN a GIQLNearest with target table not registered - WHEN giqlnearest_sql is called - THEN ValueError is raised listing available tables. + GIVEN a GIQLNearest whose target table is not registered + WHEN the NEAREST expander runs + THEN ValueError is raised listing the unresolved table. 
""" + # Arrange sql = ( "SELECT * FROM NEAREST(unknown_table, reference := 'chr1:1000-2000', k := 3)" ) - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) + # Act & assert with pytest.raises(ValueError, match="not found in tables"): - generator.generate(ast) + _generate_through_passes(sql, tables_with_peaks_and_genes) def test_intersects_sql_unqualified_column(self): """ @@ -1085,57 +1040,47 @@ def test_intersects_sql_unqualified_column(self): ) assert output == expected - def test_giqlnearest_sql_stranded_unqualified_reference( + def test_expand_nearest_should_resolve_strand_when_reference_unqualified( self, tables_with_peaks_and_genes ): """ - GIVEN a GIQLNearest with stranded := true and unqualified column reference - WHEN giqlnearest_sql is called + GIVEN a GIQLNearest with stranded := true and an unqualified column reference + WHEN the NEAREST expander runs THEN Strand column is resolved without table prefix. """ - - # Create NEAREST with stranded=True and an unqualified column reference - # The reference is an unqualified column (no table prefix) - nearest = GIQLNearest( - this=exp.Table(this=exp.Identifier(this="genes")), - reference=exp.Column(this=exp.Identifier(this="interval")), - k=exp.Literal.number(3), - stranded=exp.Boolean(this=True), + # Arrange — the reference is an unqualified column (no table prefix). 
+ sql = ( + "SELECT * FROM peaks CROSS JOIN LATERAL " + "NEAREST(genes, reference := interval, k := 3, stranded := true)" ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - output = generator.giqlnearest_sql(nearest) + # Act + output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Should produce valid output with unqualified strand column + # Assert assert "LIMIT 3" in output - # The strand column should be unqualified (no table prefix) assert '"strand"' in output - def test_giqlnearest_sql_identifier_target(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_ordered_subquery_for_literal_reference( + self, tables_with_peaks_and_genes + ): """ - GIVEN a GIQLNearest where target is an Identifier (not Table or Column) - WHEN giqlnearest_sql is called - THEN Target is converted to string and lookup proceeds. + GIVEN a GIQLNearest with a standalone literal reference + WHEN the NEAREST expander runs + THEN it produces a standalone ordered, limited subquery over the target + table with no correlated LATERAL. 
""" + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" - # Use exp.Identifier directly - not Table or Column - # This triggers the else branch at line 830 where str(target) is called - nearest = GIQLNearest( - this=exp.Identifier(this="genes"), - reference=exp.Literal.string("chr1:1000-2000"), - k=exp.Literal.number(3), - ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - output = generator.giqlnearest_sql(nearest) + # Act + output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Should succeed and produce valid SQL + # Assert assert "genes" in output + assert "ORDER BY" in output assert "LIMIT 3" in output + assert "LATERAL" not in output @given( bool_repr=st.sampled_from(["true", "TRUE", "True", "1", "yes", "YES"]), @@ -1808,7 +1753,7 @@ def test_giqlnearest_should_canonicalize_reference_column_when_reference_is_one_ A 0-based half-open target table (bed_a) and an explicit reference column from a 1-based closed table (vcf_b). When: - giqlnearest_sql is called. + the NEAREST expander runs. Then: It should wrap the reference-side start as (start - 1), leave its end raw, and leave the target side raw. @@ -1841,7 +1786,7 @@ def test_giqlnearest_should_canonicalize_outer_table_columns_when_reference_is_i target table (bed_a) joined via CROSS JOIN LATERAL with no ``reference`` argument on NEAREST. When: - giqlnearest_sql is called. + the NEAREST expander runs. 
Then: It should canonicalize the outer table's columns based on vcf_b's convention — wrapping start as (vcf_b."start" - 1) and diff --git a/tests/integration/datafusion/test_cross_target_oracle.py b/tests/integration/datafusion/test_cross_target_oracle.py index 3d4b864..f67ff9b 100644 --- a/tests/integration/datafusion/test_cross_target_oracle.py +++ b/tests/integration/datafusion/test_cross_target_oracle.py @@ -14,13 +14,20 @@ genuinely divergent SQL across targets (the DuckDB IEJoin vs. the binned equi-join). -NEAREST's expansion uses a correlated ``LATERAL`` subquery, which DataFusion has -no physical plan for today; its generic-vs-duckdb equivalence case runs both on -DuckDB, and the full three-target oracle is pinned by a -``pytest.raises(match="OuterReferenceColumn")`` test (#142) that fails loudly on -an unrelated error and trips "DID NOT RAISE" — forcing conversion to a real -identity test — when DataFusion gains correlated LATERAL. DISJOIN has an -analogous pending-#153 gap (duplicate ``end`` output names). +NEAREST's correlated expansion is capability-driven (#142): lateral-capable +targets (generic, duckdb) emit the portable ``LATERAL`` subquery, while +DataFusion — which has no correlated-LATERAL physical plan — gets a decorrelated +window-function fallback. For **explicitly-projected** queries (those selecting +named columns) the two forms return identical rows, so the full three-target +identity oracle now runs on every target for that projection shape (the former +``_unsupported_pending_142`` ``pytest.raises`` pin has been promoted to a real +identity test). 
The identity claim is narrowed to explicit projections because a +``SELECT *`` / ``SELECT b.*`` over a correlated NEAREST on DataFusion additionally +exposes the fallback's reserved ``__giql_x_*`` rank/key columns, a divergent +output schema from the LATERAL form's — a known limitation pinned by the +``xfail`` ``test_correlated_nearest_star_projection_diverges_on_datafusion`` case +below and tracked for a query-level fix by #160 (dependent on #146). DISJOIN has +an analogous pending-#153 gap (duplicate ``end`` output names). """ import pytest @@ -204,32 +211,384 @@ def test_standalone_nearest_k1_agrees_generic_vs_duckdb_on_duckdb( engines={"generic": "duckdb"}, ) - def test_nearest_on_datafusion_unsupported_pending_142(self, cross_target_oracle): - """Test the full NEAREST oracle raises DataFusion's missing-LATERAL error. + def test_correlated_nearest_k1_agrees_across_all_targets(self, cross_target_oracle): + """Test correlated NEAREST k=1 returns identical rows on every target. Given: - The single-row NEAREST query and a candidate gene on chr1. + A single-row peaks table and three candidate genes at varying + distances on chr1. When: - The oracle runs all three targets — the datafusion target executes - the correlated LATERAL on DataFusion, which has no physical plan. + A correlated ``CROSS JOIN LATERAL NEAREST(..., k := 1)`` query runs + for every target — the generic and duckdb targets emit the portable + LATERAL form (executed on DuckDB, the lateral-capable engine), and + the datafusion target emits the decorrelated window-function fallback + the #142 expander produces (executed on DataFusion). Then: - DataFusion should raise its ``OuterReferenceColumn`` "not - implemented" error. 
This pins the known #142 gap: the ``match`` - narrows to the LATERAL signature so an unrelated/reworded DataFusion - error fails loudly, and a closed gap (no exception) trips pytest's - "DID NOT RAISE", forcing this to be converted into a real - cross-target identity test when DataFusion gains correlated LATERAL. + Every target should return the single nearest gene and agree. + + Promoted from the ``_unsupported_pending_142`` expected-failure pin: + DataFusion now plans correlated NEAREST through the capability-driven + window-function fallback, so the full three-target oracle is a real + identity test rather than a ``pytest.raises`` placeholder. The generic + target is routed to DuckDB because its portable SQL is the LATERAL form, + which only the datafusion-specific fallback decorrelates for DataFusion. """ # Arrange / Act / Assert - with pytest.raises(Exception, match="OuterReferenceColumn"): - cross_target_oracle( - "SELECT a.chrom, a.start AS a_start, b.start AS b_start " - "FROM peaks a " - "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", - peaks=[("chr1", 200, 300)], - genes=[("chr1", 280, 290)], - expected=[("chr1", 200, 280)], - ) + cross_target_oracle( + "SELECT a.chrom, a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ], + expected=[("chr1", 200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_k2_returns_two_nearest_across_targets( + self, cross_target_oracle + ): + """Test correlated NEAREST k=2 picks the two nearest on every target. + + Given: + One peak and four candidate genes, more than k of them on the peak's + chromosome at distinct distances. + When: + A correlated ``NEAREST(..., k := 2)`` runs on every target — DuckDB + via the LATERAL form, DataFusion via the decorrelated window fallback. 
+ Then: + Every target should return the two nearest genes and agree, pinning + the top-k fan-out of the fallback against the LATERAL form. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 2) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ("chr1", 310, 320), + ], + expected=[(200, 280), (200, 310)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_duplicate_reference_rows_fan_out( + self, cross_target_oracle + ): + """Test correlated NEAREST fans the top-k out to duplicate reference rows. + + Given: + Two identical peak rows and two candidate genes. + When: + A correlated ``NEAREST(..., k := 1)`` runs on every target. + Then: + Every target should return the nearest gene once per duplicate peak + (two rows), pinning the fallback's DISTINCT-then-rejoin fan-out so a + duplicate outer row is not collapsed. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300), ("chr1", 200, 300)], + genes=[("chr1", 280, 290), ("chr1", 50, 60)], + expected=[(200, 280), (200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_partitions_by_chromosome(self, cross_target_oracle): + """Test correlated NEAREST keys the nearest per outer chromosome. + + Given: + Peaks on chr1 and chr2 and candidate genes on both chromosomes. + When: + A correlated ``NEAREST(..., k := 1)`` runs on every target. + Then: + Each peak should pair with the nearest gene on its own chromosome and + all targets agree, pinning the fallback's PARTITION BY reference key + across distinct outer keys. 
+ """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.chrom AS chrom, a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300), ("chr2", 200, 300)], + genes=[("chr1", 280, 290), ("chr2", 500, 510), ("chr2", 205, 215)], + expected=[("chr1", 200, 280), ("chr2", 200, 205)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_max_distance_boundary(self, cross_target_oracle): + """Test correlated NEAREST drops candidates beyond max_distance everywhere. + + Given: + A peak and two genes, one just inside and one far beyond a + ``max_distance`` threshold. + When: + A correlated ``NEAREST(..., k := 5, max_distance := 100)`` runs on + every target. + Then: + Every target should return only the in-threshold gene, pinning the + ``max_distance`` filter through both the LATERAL and fallback forms. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 5, max_distance := 100) b", + peaks=[("chr1", 200, 300)], + genes=[("chr1", 360, 400), ("chr1", 5000, 5100)], + expected=[(200, 360)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_stranded_matches_strand(self, cross_target_oracle): + """Test stranded correlated NEAREST matches strand on every target. + + Given: + A ``+`` peak and two genes — a slightly farther ``+`` gene and a + nearer ``-`` gene. + When: + A correlated ``NEAREST(..., k := 1, stranded := true)`` runs on every + target. + Then: + Every target should return the same-strand (``+``) gene even though + the opposite-strand gene is nearer, in agreement. 
+ """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 1, stranded := true) b", + tables=[Table("peaks"), Table("genes")], + columns=_STRANDED_COLUMNS, + peaks=[("chr1", 200, 300, "+")], + genes=[("chr1", 280, 290, "+"), ("chr1", 250, 260, "-")], + expected=[(200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_signed_distance_agrees(self, cross_target_oracle): + """Test signed correlated NEAREST reports signed distances everywhere. + + Given: + A peak with one upstream and one downstream candidate gene. + When: + A correlated ``NEAREST(..., k := 2, signed := true)`` projects the + ``distance`` column on every target. + Then: + Every target should report a negative distance for the upstream gene + and a positive one for the downstream gene, in agreement. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start, b.distance AS d " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 2, signed := true) b", + peaks=[("chr1", 200, 300)], + genes=[("chr1", 50, 60), ("chr1", 360, 400)], + expected=[(200, 50, -141), (200, 360, 61)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_k_th_distance_tie_breaks_on_coordinates( + self, cross_target_oracle + ): + """Test the (start, end) tiebreaker picks the same k-th candidate everywhere. + + Given: + One peak and three genes where two candidates are tied at the k-th + (k=1) distance — both 100 bp away, one upstream and one downstream of + the peak — so only the ``(start, end)`` tiebreaker can choose between + them (the lower ``(start, end)`` wins). 
+ When: + A correlated ``NEAREST(..., k := 1)`` runs on every target — DuckDB via + the LATERAL form's ``ORDER BY ABS(distance), start, end LIMIT 1`` and + DataFusion via the fallback's matching ``ROW_NUMBER()`` ordering. + Then: + Every target should return the lower-coordinate tied candidate (the + upstream gene), so the LATERAL and window forms agree on the tie rather + than ordering it by engine-dependent chance. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300)], + # Upstream gene ends at 100 (gap 100); downstream gene starts at 400 + # (gap 100). Both tie at distance 101 (gap + 1); (start, end) breaks the + # tie in favor of the upstream gene (start 50 < start 400). + genes=[("chr1", 50, 100), ("chr1", 400, 450)], + expected=[(200, 50)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_stranded_opposite_strands_same_position( + self, cross_target_oracle + ): + """Test stranded NEAREST keys per-strand for co-located opposite-strand rows. + + Given: + Two peaks at the *same* position but on opposite strands (``+`` and + ``-``), and one same-position ``+`` gene plus one ``-`` gene, so the + strand-augmented reference key must keep each outer row's nearest + strand-matched independently. 
+ """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.strand AS a_strand, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 1, stranded := true) b", + tables=[Table("peaks"), Table("genes")], + columns=_STRANDED_COLUMNS, + peaks=[("chr1", 200, 300, "+"), ("chr1", 200, 300, "-")], + genes=[("chr1", 280, 290, "+"), ("chr1", 250, 260, "-")], + expected=[("+", 280), ("-", 250)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_max_distance_keeps_k_survivors( + self, cross_target_oracle + ): + """Test max_distance with k>1 keeps every in-threshold survivor everywhere. + + Given: + One peak and four genes where three sit within a ``max_distance`` + threshold at distinct distances and one sits beyond it, with k larger + than the survivor count. + When: + A correlated ``NEAREST(..., k := 3, max_distance := 200)`` runs on + every target — DuckDB via the LATERAL form, DataFusion via the window + fallback. + Then: + Every target should return exactly the three in-threshold genes (the + beyond-threshold gene dropped), proving ``max_distance`` and the top-k + survivor set interact identically across both forms. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 3, max_distance := 200) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 350, 360), + ("chr1", 420, 430), + ("chr1", 480, 490), + ("chr1", 5000, 5100), + ], + expected=[(200, 350), (200, 420), (200, 480)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_unaliased_lateral_agrees_across_targets( + self, cross_target_oracle + ): + """Test an unaliased correlated NEAREST agrees across targets (B3 on-engine). 
+ + Given: + A correlated ``CROSS JOIN LATERAL NEAREST(...)`` written *without* a + table alias — legitimate GIQL that, before B3, raised on DataFusion + while running on DuckDB. + When: + The query runs on every target — DuckDB via the LATERAL form and + DataFusion via the decorrelated fallback, which now synthesizes the + missing alias instead of asserting one. + Then: + Every target should return the single nearest gene and agree, proving + the synthesized-alias fallback both transpiles and executes on the real + DataFusion engine. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1)", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ], + expected=[(200,)], + engines={"generic": "duckdb"}, + ) + + @pytest.mark.xfail( + strict=True, + reason="#160: SELECT b.* over a correlated NEAREST on DataFusion exposes " + "the decorrelated fallback's reserved __giql_x_rk_*/__giql_x_rn columns, " + "so the output schema diverges from the LATERAL form's on DuckDB. The " + "cross-target identity claim is narrowed to explicitly-projected queries " + "until #160 (dependent on #146) adds a query-level wrapper that projects " + "the reserved columns away on the DataFusion path; this flips to a real " + "identity test then.", + ) + def test_correlated_nearest_star_projection_diverges_on_datafusion( + self, cross_target_oracle + ): + """Test SELECT b.* over a correlated NEAREST diverges per target (pins #160). + + Given: + A single-row peak and three candidate genes on chr1. + When: + A correlated ``NEAREST(..., k := 1)`` query projects ``b.*`` on every + target — DuckDB emits the LATERAL form (``genes.* + distance``) while + DataFusion's fallback additionally surfaces the reserved + ``__giql_x_rk_*`` / ``__giql_x_rn`` columns on ``b``. 
+ Then: + The cross-target row sets should NOT agree (DataFusion's rows carry the + extra reserved columns), so the oracle's identity assertion fails. This + xfail pins the known divergence so it is not silently forgotten; it + flips to a passing identity test when #160 hides the reserved columns. + """ + # Arrange / Act & Assert (xfail: identity assertion is expected to fail) + cross_target_oracle( + "SELECT b.* " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ], + expected=[("chr1", 280, 290, 0)], + engines={"generic": "duckdb"}, + ) + + +#: A chrom/start/end/strand schema for the stranded NEAREST oracle cases (the +#: default oracle schema carries no strand column). +_STRANDED_COLUMNS = ( + ("chrom", "utf8"), + ("start", "int64"), + ("end", "int64"), + ("strand", "utf8"), +) class TestCrossTargetOracleIntersectsAnyAll: diff --git a/tests/test_nearest_transpilation.py b/tests/test_nearest_transpilation.py index 2488cb0..9caac74 100644 --- a/tests/test_nearest_transpilation.py +++ b/tests/test_nearest_transpilation.py @@ -7,26 +7,48 @@ import pytest from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers the NEAREST expander) from giql import Table from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import DataFusionTarget +from giql.targets import GenericTarget + + +def _generate_for_target(sql: str, tables: Tables, target) -> str: + """Parse, run passes 1-3 against *target*, then generate SQL. + + Drives the expander for a specific :class:`~giql.targets.Target` so a + capability-dependent shape (e.g. 
DataFusion's decorrelated window fallback, + chosen because ``supports_lateral`` is False) can be asserted without an + engine. + """ + ast = parse_one(sql, dialect=GIQLDialect) + ast = resolve_operator_refs(ast, tables) + ast = canonicalize_coordinates(ast) + ast = ExpandOperators(target, tables).transform(ast) + return BaseGIQLGenerator(tables=tables).generate(ast) def _generate(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. + """Parse, run normalization passes 1-3, then generate SQL. - Operator resolution and coordinate canonicalization moved out of the emitter - and into the ResolveOperatorRefs / CanonicalizeCoordinates passes (epic #114, - issues #118-#123). Emitter-level tests must run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST. + Operator resolution, coordinate canonicalization, and operator expansion + moved out of the emitter into the ResolveOperatorRefs / CanonicalizeCoordinates + / ExpandOperators passes (epics #114, #137). NEAREST is now produced by its + registered expander (issue #142) rather than a ``giqlnearest_sql`` emitter, so + these tests must run pass 3 before generating, exactly as + :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a + bare parsed AST. """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -160,3 +182,234 @@ def test_nearest_with_signed(self, tables_with_peaks_and_genes): assert "ELSE -(" in output, ( f"Expected signed distance with negation for upstream, got:\n{output}" ) + + +class TestNearestDataFusionFallbackShape: + """Engine-free transpile-shape checks for the DataFusion window fallback (A8). 
+ + A correlated NEAREST on the DataFusion target (``supports_lateral`` is False) + expands to the decorrelated window-function form. These assert its structural + invariants without running an engine: the window is present, the top-k filter + is a ``<= k`` predicate, no correlated ``LATERAL`` survives, and the candidate + cross-join and the window live at separate query levels. + """ + + def test_fallback_emits_window_with_topk_and_no_lateral( + self, tables_with_peaks_and_genes + ): + """Test the DataFusion fallback emits a windowed top-k with no LATERAL. + + Given: + A correlated NEAREST(genes, k := 1) on the DataFusion target. + When: + Transpiling. + Then: + It should emit a ROW_NUMBER() window, a `<= 1` top-k predicate, no + surviving LATERAL, and the cross-join and window at separate query + levels. + """ + # Arrange + sql = ( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 1) AS b" + ) + + # Act + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + # Assert + assert "ROW_NUMBER(" in output.upper() + assert "OVER (" in output.upper() + assert "<= 1" in output + assert "LATERAL" not in output.upper() + # The candidate cross-join sits one level below the window: the window's + # FROM is a parenthesized subquery, so a CROSS JOIN appears nested inside. + assert "CROSS JOIN" in output.upper() + + def test_fallback_stranded_emits_window_and_strand_match( + self, tables_with_peaks_and_genes + ): + """Test the stranded DataFusion fallback keeps a strand match, no LATERAL. + + Given: + A stranded correlated NEAREST on the DataFusion target. + When: + Transpiling. + Then: + It should emit the window form, keep a strand equality in the + candidate WHERE, and surface no LATERAL. 
+ """ + # Arrange + sql = ( + "SELECT * FROM peaks CROSS JOIN LATERAL " + "NEAREST(genes, reference := peaks.interval, k := 1, stranded := true) AS b" + ) + + # Act + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + # Assert + assert "ROW_NUMBER(" in output.upper() + assert "LATERAL" not in output.upper() + assert 'peaks."strand"' in output + assert 'genes."strand"' in output + + def test_fallback_k_greater_than_one_uses_k_in_topk_predicate( + self, tables_with_peaks_and_genes + ): + """Test the DataFusion fallback carries the requested k in its top-k filter. + + Given: + A correlated NEAREST(genes, k := 3) on the DataFusion target. + When: + Transpiling. + Then: + The top-k predicate should carry the requested k (`<= 3`) rather than a + LIMIT, and no LATERAL should survive. + """ + # Arrange + sql = ( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3) AS b" + ) + + # Act + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + # Assert + assert "ROW_NUMBER(" in output.upper() + assert "<= 3" in output + assert "LATERAL" not in output.upper() + + +class TestNearestUnaliasedCorrelatedFallback: + """The fallback synthesizes a LATERAL alias when the user omits one (B3).""" + + def test_expand_nearest_should_synthesize_alias_when_correlated_lateral_unaliased_on_nonlateral_target( # noqa: E501 + self, tables_with_peaks_and_genes + ): + """Test an unaliased correlated NEAREST transpiles on a non-LATERAL target. + + Given: + A correlated NEAREST whose surrounding CROSS JOIN LATERAL carries no + table alias — legitimate GIQL that transpiles fine on lateral-capable + engines — targeted at DataFusion (``supports_lateral`` is False), which + takes the decorrelated window fallback. + When: + Transpiling for the DataFusion target. 
+ Then: + It should not raise: the fallback synthesizes an alias via + ``ctx.alias()`` instead of asserting one is present, and emits the + decorrelated window form (a synthesized ``__giql_x_`` alias, no + LATERAL). + """ + # Arrange + sql = ( + "SELECT a.start FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1)" + ) + + # Act + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + # Assert + assert "ROW_NUMBER(" in output.upper() + assert "LATERAL" not in output.upper() + assert "__giql_x_" in output + + def test_expand_nearest_should_not_assert_on_unaliased_lateral_under_O(self): + """Test the unaliased correlated fallback survives ``python -O``. + + Given: + A fresh ``python -O`` interpreter (asserts stripped), in which an + asserted alias precondition would degrade to a ``NoneType`` deref + rather than a clear error. + When: + Transpiling an unaliased correlated NEAREST for the DataFusion dialect. + Then: + It should transpile without raising and emit the window fallback, + proving the alias is synthesized (not asserted) so the optimized + interpreter cannot strip the guard into an opaque crash. 
+ """ + # Arrange + import subprocess + import sys + + code = ( + "from giql import transpile, Table; " + "sql = transpile(" + "'SELECT a.start FROM peaks a CROSS JOIN LATERAL " + "NEAREST(genes, reference := a.interval, k := 1)', " + "tables=[Table('peaks'), Table('genes')], dialect='datafusion'); " + "assert 'ROW_NUMBER(' in sql.upper(), sql; " + "assert 'LATERAL' not in sql.upper(), sql; " + "print('ok')" + ) + + # Act + result = subprocess.run( + [sys.executable, "-O", "-c", code], + capture_output=True, + text=True, + ) + + # Assert + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "ok" + + +class TestNearestFallbackDetachContract: + """The fallback detaches its NEAREST node so the pass's replace is a no-op (A10).""" + + def test_fallback_detaches_node_and_rewritten_join_survives_pass( + self, tables_with_peaks_and_genes + ): + """Test the fallback detaches the NEAREST subtree and leaves the rewritten join. + + Given: + A correlated NEAREST on the DataFusion target, captured before the + ExpandOperators pass runs. + When: + Running the pass (which dispatches to the decorrelated fallback). + Then: + The original NEAREST node should be detached from the returned tree — + its surrounding LATERAL is swapped out, so it no longer reaches the + result root, making the pass's own ``node.replace`` a no-op — and the + rewritten plain JOIN (no LATERAL, ROW_NUMBER window present) should + survive in the returned tree. 
+ """ + # Arrange + from giql.expressions import GIQLNearest + + sql = ( + "SELECT a.start FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) AS b" + ) + ast = parse_one(sql, dialect=GIQLDialect) + ast = resolve_operator_refs(ast, tables_with_peaks_and_genes) + ast = canonicalize_coordinates(ast) + nearest = ast.find(GIQLNearest) + assert nearest is not None and nearest.root() is ast + + # Act + result = ExpandOperators( + DataFusionTarget(), tables_with_peaks_and_genes + ).transform(ast) + + # Assert + # The fallback swaps out the LATERAL holding the NEAREST, so the node's + # subtree is detached: it no longer reaches the result root, which is what + # makes the pass's ``node.replace`` a no-op. + assert nearest.root() is not result + assert not list(result.find_all(GIQLNearest)) + output = BaseGIQLGenerator(tables=tables_with_peaks_and_genes).generate(result) + assert "LATERAL" not in output.upper() + assert "ROW_NUMBER(" in output.upper()