diff --git a/src/giql/expanders/intersects.py b/src/giql/expanders/intersects.py new file mode 100644 index 0000000..cbda7d6 --- /dev/null +++ b/src/giql/expanders/intersects.py @@ -0,0 +1,244 @@ +"""Generic expanders for the spatial predicates and set predicates (epic #137). + +Migrates INTERSECTS / CONTAINS / WITHIN and the ``ANY`` / ``ALL`` set predicates +off the legacy ``*_sql`` emitters on :class:`giql.generators.base.BaseGIQLGenerator` +and onto the operator-expander registry. Each expander turns one predicate node +into standard sqlglot AST built from the pass-1 :class:`~giql.resolver.ResolvedColumn` +metadata (already canonicalized to 0-based half-open by pass 2), so the emitted SQL +is byte-identical to the strings the legacy emitter produced. + +These are *node-local* predicate rewrites: an INTERSECTS / CONTAINS / WITHIN node +expands to a boolean ``(chrom = ... AND start < ... AND end > ...)`` expression that +replaces it in place. The whole-query column-to-column **join** rewrites (the binned +equi-join and the DuckDB IEJoin) remain capability-gated pre-pass transformers in +:mod:`giql.transformer` keyed on ``capabilities.range_join_strategy`` — they consume +a column-to-column INTERSECTS *join* before this pass runs, so by the time a +column-to-column INTERSECTS reaches an expander it is a residual predicate (e.g. +inside an ``OR``, or a join shape the transformer declined) that the legacy emitter +also rendered as a plain predicate. The expander handles that residual the same way. + +Only :class:`~giql.targets.GenericTarget` expanders are registered: spatial-predicate +*emission* is portable SQL-92 and does not vary by engine, so one generic expander +covers every target via the registry's ``(generic, op)`` fallback. 
+""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot import maybe_parse + +from giql.dialect import GIQLDialect +from giql.expander import ExpansionContext +from giql.expander import register +from giql.expressions import Contains +from giql.expressions import Intersects +from giql.expressions import SpatialSetPredicate +from giql.expressions import Within +from giql.range_parser import ParsedRange +from giql.range_parser import RangeParser +from giql.resolver import ResolvedColumn +from giql.targets import GenericTarget + + +def _fragment(fragment: str) -> exp.Expression: + """Parse a resolved SQL fragment (e.g. ``a."end"`` / ``'chr1'``) into AST. + + The pass-1 :class:`~giql.resolver.ResolvedColumn` carries column references as + pre-canonicalized SQL string fragments; parse them through the GIQL dialect so + the rebuilt predicate reserializes identically to the legacy emitter's string. + """ + parsed = maybe_parse(fragment, dialect=GIQLDialect) + if parsed is None: + # maybe_parse returns None only for an empty/None input; a ResolvedColumn + # fragment is never empty, so this is an internal invariant violation. + raise ValueError(f"Could not parse resolved column fragment: {fragment!r}") + return parsed + + +def _predicate_column(ctx: ExpansionContext, arg: str) -> ResolvedColumn: + """Return the :class:`ResolvedColumn` for predicate operand *arg*. + + Mirrors :meth:`giql.generators.base.BaseGIQLGenerator._predicate_operand`: the + expander consumes only the pass-1 resolution; a missing column means pass 1 did + not run (an internal invariant violation), so raise the historical message. + """ + resolution = ctx.resolution + if resolution is not None: + resolved = resolution.column(arg) + if resolved is not None: + return resolved + raise ValueError( + f"Spatial predicate operand {arg!r} was not resolved; run the " + "ResolveOperatorRefs pass (transpile pipeline) before generation." 
+ ) + + +def _range_predicate( + column: ResolvedColumn, parsed: ParsedRange, op_type: str +) -> exp.Expression: + """Build the boolean AST for ``column <op_type> <range>``. + + Reproduces :meth:`BaseGIQLGenerator._generate_range_predicate` as AST. The + column fragments are already canonical 0-based half-open (pass 2); the parsed + range is canonicalized by the caller. Returns a parenthesized boolean. + """ + chrom = _fragment(column.chrom) + start = _fragment(column.start) + end = _fragment(column.end) + chrom_lit = exp.Literal.string(parsed.chromosome) + r_start = exp.Literal.number(parsed.start) + r_end = exp.Literal.number(parsed.end) + + if op_type == "intersects": + # Ranges overlap if: start1 < end2 AND end1 > start2 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.LT(this=start, expression=r_end), + exp.GT(this=end, expression=r_start), + ) + elif op_type == "contains": + if parsed.end == parsed.start + 1: + # Point query: start1 <= point < end1 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.LTE(this=start, expression=r_start), + exp.GT(this=end, expression=r_start), + ) + else: + # Range query: start1 <= start2 AND end1 >= end2 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.LTE(this=start, expression=r_start), + exp.GTE(this=end, expression=r_end), + ) + elif op_type == "within": + # left within right: start1 >= start2 AND end1 <= end2 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.GTE(this=start, expression=r_start), + exp.LTE(this=end, expression=r_end), + ) + else: + raise ValueError(f"Unknown spatial op_type: {op_type!r}") + + return exp.paren(cond) + + +def _column_join( + left: ResolvedColumn, right: ResolvedColumn, op_type: str +) -> exp.Expression: + """Build the boolean AST for a column-to-column spatial predicate. + + Reproduces :meth:`BaseGIQLGenerator._generate_column_join` as AST. Both + operands' fragments are pre-canonicalized (pass 2). Returns a parenthesized + boolean. 
+ """ + l_chrom, r_chrom = _fragment(left.chrom), _fragment(right.chrom) + l_start, r_start = _fragment(left.start), _fragment(right.start) + l_end, r_end = _fragment(left.end), _fragment(right.end) + + if op_type == "intersects": + cond = exp.and_( + exp.EQ(this=l_chrom, expression=r_chrom), + exp.LT(this=l_start, expression=r_end), + exp.GT(this=l_end, expression=r_start), + ) + elif op_type == "contains": + cond = exp.and_( + exp.EQ(this=l_chrom, expression=r_chrom), + exp.LTE(this=l_start, expression=r_start), + exp.GTE(this=l_end, expression=r_end), + ) + elif op_type == "within": + cond = exp.and_( + exp.EQ(this=l_chrom, expression=r_chrom), + exp.GTE(this=l_start, expression=r_start), + exp.LTE(this=l_end, expression=r_end), + ) + else: + raise ValueError(f"Unknown spatial op_type: {op_type!r}") + + return exp.paren(cond) + + +def _expand_spatial_op( + node: exp.Expression, ctx: ExpansionContext, op_type: str +) -> exp.Expression: + """Expand one INTERSECTS / CONTAINS / WITHIN node to a boolean predicate. + + Dispatches on the right operand exactly as the legacy emitter did: the + presence of a resolved right *column* — keyed off + ``ctx.resolution.column("expression")``, the slot pass 1 attaches a + :class:`ResolvedColumn` to when the right operand is a column reference — + selects the column-to-column path; its absence means the right operand is a + literal range, parsed in place. + """ + resolution = ctx.resolution + right_column = resolution.column("expression") if resolution is not None else None + left = _predicate_column(ctx, "this") + + if right_column is not None: + return _column_join(left, right_column, op_type) + + # Literal range string (e.g. interval INTERSECTS 'chr1:1000-2000'). Reproduce + # the legacy emitter's parse-and-wrap-error behavior verbatim: any parse + # failure (including the RangeParser's own ValueError) is wrapped in the + # historical "Could not parse genomic range" message. 
+ right_expr = node.args.get("expression") + raw = right_expr.sql(dialect=GIQLDialect) if right_expr is not None else "" + try: + range_str = raw.strip("'\"") + parsed = RangeParser.parse(range_str).to_zero_based_half_open() + return _range_predicate(left, parsed, op_type) + except Exception as e: + raise ValueError(f"Could not parse genomic range: {raw}. Error: {e}") from e + + +@register(GenericTarget, Intersects) +def expand_intersects(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand an INTERSECTS predicate to standard boolean SQL AST.""" + return _expand_spatial_op(node, ctx, "intersects") + + +@register(GenericTarget, Contains) +def expand_contains(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a CONTAINS predicate to standard boolean SQL AST.""" + return _expand_spatial_op(node, ctx, "contains") + + +@register(GenericTarget, Within) +def expand_within(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a WITHIN predicate to standard boolean SQL AST.""" + return _expand_spatial_op(node, ctx, "within") + + +@register(GenericTarget, SpatialSetPredicate) +def expand_spatial_set( + node: exp.Expression, ctx: ExpansionContext +) -> exp.Expression: + """Expand a quantified set predicate (``ANY`` / ``ALL``) to boolean SQL AST. + + Reproduces :meth:`BaseGIQLGenerator._generate_spatial_set`: the single left + column is compared against every literal range, and the per-range conditions + are OR-combined for ``ANY`` / AND-combined for ``ALL``, all wrapped in one + outer paren. 
+ """ + operator = node.args["operator"] + quantifier = node.args["quantifier"] + ranges = node.args["ranges"] + + column = _predicate_column(ctx, "this") + op_type = operator.lower() + + conditions: list[exp.Expression] = [] + for range_expr in ranges: + range_str = range_expr.sql(dialect=GIQLDialect).strip("'\"") + parsed = RangeParser.parse(range_str).to_zero_based_half_open() + conditions.append(_range_predicate(column, parsed, op_type)) + + if quantifier.upper() == "ANY": + combined = exp.or_(*conditions) + else: + combined = exp.and_(*conditions) + + return exp.paren(combined) diff --git a/src/giql/expressions.py b/src/giql/expressions.py index 0c6bc80..4949fdd 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -114,8 +114,9 @@ class SpatialPredicate(exp.Binary): #: per-operator ``GIQL_EXPAND`` flag mirrors ``GIQL_CANONICALIZE``: an operator #: takes the new AST-expansion path only when it sets ``GIQL_EXPAND = True`` *and* #: an expander is registered for it; otherwise the legacy ``*_sql`` emitter runs. -#: Every operator defaults to ``False`` here, so the pass is a strict no-op until a -#: later migration step (#140+) flips one operator's flag alongside its expander. +#: This is the opt-out default: an operator inherits it and stays on the legacy +#: emitter until its migration step flips the flag to ``True`` alongside its +#: registered expander. Operators already migrated override it on their own class. _EXPAND = False @@ -126,7 +127,12 @@ class Intersects(SpatialPredicate): """ GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141). A literal-range or + #: residual column-to-column INTERSECTS *predicate* expands through + #: ``giql.expanders.intersects``; a column-to-column INTERSECTS *join* is + #: consumed by the capability-gated binned / IEJoin pre-pass transformers + #: before this pass runs, so the predicate expander never sees it. 
+ GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), @@ -141,7 +147,9 @@ class Contains(SpatialPredicate): """ GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141); expands through + #: ``giql.expanders.intersects``. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), @@ -156,7 +164,9 @@ class Within(SpatialPredicate): """ GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141); expands through + #: ``giql.expanders.intersects``. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), @@ -180,7 +190,9 @@ class SpatialSetPredicate(exp.Expression): } GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141); expands through + #: ``giql.expanders.intersects``. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index a4be873..cc494dc 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -3,13 +3,8 @@ from giql.canonical import decanonical_end from giql.canonical import decanonical_start -from giql.expressions import Contains from giql.expressions import GIQLDisjoin from giql.expressions import GIQLNearest -from giql.expressions import Intersects -from giql.expressions import SpatialSetPredicate -from giql.expressions import Within -from giql.range_parser import ParsedRange from giql.range_parser import RangeParser from giql.resolver import META_KEY from giql.resolver import OperatorResolution @@ -41,45 +36,12 @@ def __init__(self, tables: Tables | None = None, **kwargs): super().__init__(**kwargs) self.tables = tables or Tables() - def intersects_sql(self, expression: Intersects) -> str: - """Generate standard SQL for INTERSECTS. 
- - :param expression: - INTERSECTS expression node - :return: - SQL predicate string - """ - return self._generate_spatial_op(expression, "intersects") - - def contains_sql(self, expression: Contains) -> str: - """Generate standard SQL for CONTAINS. - - :param expression: - CONTAINS expression node - :return: - SQL predicate string - """ - return self._generate_spatial_op(expression, "contains") - - def within_sql(self, expression: Within) -> str: - """Generate standard SQL for WITHIN. - - :param expression: - WITHIN expression node - :return: - SQL predicate string - """ - return self._generate_spatial_op(expression, "within") - - def spatialsetpredicate_sql(self, expression: SpatialSetPredicate) -> str: - """Generate SQL for spatial set predicates (ANY/ALL). - - :param expression: - SpatialSetPredicate expression node - :return: - SQL predicate string - """ - return self._generate_spatial_set(expression) + # INTERSECTS / CONTAINS / WITHIN and the ANY/ALL set predicates are migrated + # to the ExpandOperators registry (#141): they expand to standard boolean AST + # in ``giql.expanders.intersects`` before generation, so the generator no + # longer carries ``intersects_sql`` / ``contains_sql`` / ``within_sql`` / + # ``spatialsetpredicate_sql`` emitters or their ``_generate_spatial_*`` / + # ``_predicate_operand`` helpers. def giqlnearest_sql(self, expression: GIQLNearest) -> str: """Generate SQL for NEAREST function. @@ -601,217 +563,6 @@ def _generate_distance_case( f"ELSE ({start_a} - {end_b} + 1) END END" ) - def _predicate_operand(self, expression: exp.Expression, arg: str) -> ResolvedColumn: - """Return the :class:`ResolvedColumn` for a spatial predicate operand. - - Reads the column resolution attached to *expression* by the - ``ResolveOperatorRefs`` pass (pass 1). The emitter consumes only the - resolved metadata; all name/column resolution lives in the pass. - - :param expression: - The spatial predicate node carrying the resolution metadata. 
- :param arg: - The operand slot key (``"this"`` or ``"expression"``). - :return: - The resolved column metadata. - """ - resolution = expression.meta.get(META_KEY) - if isinstance(resolution, OperatorResolution): - resolved = resolution.column(arg) - if resolved is not None: - return resolved - - raise ValueError( - f"Spatial predicate operand {arg!r} was not resolved; run the " - "ResolveOperatorRefs pass (transpile pipeline) before generation." - ) - - def _generate_spatial_op(self, expression: exp.Binary, op_type: str) -> str: - """Generate SQL for a spatial operation. - - :param expression: - AST node (Intersects, Contains, or Within) - :param op_type: - 'intersects', 'contains', or 'within' - :return: - SQL predicate string - """ - right_raw = self.sql(expression, "expression") - - # Check if right side is a column reference or a literal range string - if "." in right_raw and not right_raw.startswith("'"): - # Column-to-column join (e.g., a.interval INTERSECTS b.interval) - left = self._predicate_operand(expression, "this") - right = self._predicate_operand(expression, "expression") - return self._generate_column_join(left, right, op_type) - else: - # Literal range string (e.g., interval INTERSECTS 'chr1:1000-2000') - try: - range_str = right_raw.strip("'\"") - parsed_range = RangeParser.parse(range_str).to_zero_based_half_open() - left = self._predicate_operand(expression, "this") - return self._generate_range_predicate(left, parsed_range, op_type) - except Exception as e: - raise ValueError( - f"Could not parse genomic range: {right_raw}. Error: {e}" - ) - - def _generate_range_predicate( - self, - column: ResolvedColumn, - parsed_range: ParsedRange, - op_type: str, - ) -> str: - """Generate SQL predicate for a range operation. - - :param column: - Resolved column operand (physical chrom/start/end fragments plus the - backing :class:`~giql.table.Table` config for canonicalization). 
- :param parsed_range: - Parsed genomic range - :param op_type: - 'intersects', 'contains', or 'within' - :return: - SQL predicate string - """ - # The alias-qualified column fragments come pre-resolved on the - # ResolvedColumn, already canonicalized to 0-based half-open by - # CanonicalizeCoordinates (pass 2, issue #123). The predicate returns a - # boolean, which is encoding-invariant, so no output de-canonicalization - # is needed. - chrom_col = column.chrom - start_col = column.start - end_col = column.end - - chrom = parsed_range.chromosome - start = parsed_range.start - end = parsed_range.end - - if op_type == "intersects": - # Ranges overlap if: start1 < end2 AND end1 > start2 - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} < {end} " - f"AND {end_col} > {start})" - ) - - elif op_type == "contains": - # Point query: start1 <= point < end1 - if end == start + 1: - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} <= {start} " - f"AND {end_col} > {start})" - ) - # Range query: start1 <= start2 AND end1 >= end2 - else: - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} <= {start} " - f"AND {end_col} >= {end})" - ) - - elif op_type == "within": - # Left within right: start1 >= start2 AND end1 <= end2 - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} >= {start} " - f"AND {end_col} <= {end})" - ) - - raise ValueError(f"Unknown operation: {op_type}") - - def _generate_column_join( - self, left: ResolvedColumn, right: ResolvedColumn, op_type: str - ) -> str: - """Generate SQL for column-to-column spatial joins. - - :param left: - Resolved left column operand (e.g., for 'a.interval'). - :param right: - Resolved right column operand (e.g., for 'b.interval'). 
- :param op_type: - 'intersects', 'contains', or 'within' - :return: - SQL predicate string - """ - # The alias-qualified chrom/start/end fragments come pre-resolved on the - # ResolvedColumns, already canonicalized to 0-based half-open by - # CanonicalizeCoordinates (pass 2, issue #123). The predicate returns a - # boolean (encoding-invariant), so no output de-canonicalization is needed. - l_chrom = left.chrom - r_chrom = right.chrom - l_start = left.start - l_end = left.end - r_start = right.start - r_end = right.end - - if op_type == "intersects": - # Ranges overlap if: chrom1 = chrom2 AND start1 < end2 AND end1 > start2 - return ( - f"({l_chrom} = {r_chrom} " - f"AND {l_start} < {r_end} " - f"AND {l_end} > {r_start})" - ) - - elif op_type == "contains": - # Left contains right: chrom1 = chrom2 AND start1 <= start2 AND end1 >= end2 - return ( - f"({l_chrom} = {r_chrom} " - f"AND {l_start} <= {r_start} " - f"AND {l_end} >= {r_end})" - ) - - elif op_type == "within": - # Left within right: chrom1 = chrom2 AND start1 >= start2 AND end1 <= end2 - return ( - f"({l_chrom} = {r_chrom} " - f"AND {l_start} >= {r_start} " - f"AND {l_end} <= {r_end})" - ) - - raise ValueError(f"Unknown operation: {op_type}") - - def _generate_spatial_set(self, expression: SpatialSetPredicate) -> str: - """Generate SQL for spatial set predicates (ANY/ALL). - - Examples: - column INTERSECTS ANY(...) -> (condition1 OR condition2 OR ...) - column INTERSECTS ALL(...) -> (condition1 AND condition2 AND ...) - - :param expression: - SpatialSetPredicate expression node - :return: - SQL predicate string - """ - operator = expression.args["operator"] - quantifier = expression.args["quantifier"] - ranges = expression.args["ranges"] - - # Resolve the (single) left column operand once; every range condition - # compares against the same column. The set predicate's ranges are - # always literals, so only this operand needs resolution. 
- column = self._predicate_operand(expression, "this") - - # Parse all ranges - parsed_ranges = [] - for range_expr in ranges: - range_str = self.sql(range_expr).strip("'\"") - parsed_range = RangeParser.parse(range_str).to_zero_based_half_open() - parsed_ranges.append(parsed_range) - - op_type = operator.lower() - - # Generate conditions for each range - conditions = [] - for parsed_range in parsed_ranges: - condition = self._generate_range_predicate(column, parsed_range, op_type) - conditions.append(condition) - - # Combine with AND (for ALL) or OR (for ANY) - combinator = " OR " if quantifier.upper() == "ANY" else " AND " - return "(" + combinator.join(conditions) + ")" - def _detect_nearest_mode( self, expression: GIQLNearest, parent_expression: exp.Expression | None = None ) -> str: diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 81ee8e1..b5568e2 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -14,7 +14,9 @@ import giql.expanders # noqa: F401 from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import REGISTRY from giql.expander import ExpandOperators +from giql.expressions import Intersects from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Table @@ -151,41 +153,78 @@ def transpile( "of the binned equi-join. Pass one or the other, not both." ) + # The INTERSECTS join-rewrite override (governs the three uses below). + # + # A *target-specific* ``(target, Intersects)`` registry entry — the public + # extension hook, matched by ``ExpanderRegistry.has_override`` — takes over the + # INTERSECTS join rewrite entirely. 
``has_override`` deliberately matches only + # an *exact non-generic* entry: the built-in ``(GenericTarget(), Intersects)`` + # predicate expander is NOT a join-strategy override (it only renders the + # residual / literal-range predicates the join transformers leave behind), so + # it must not disable the join rewrite. When an override is present, all three + # built-in join paths below are bypassed for that target so the INTERSECTS node + # flows untouched into ExpandOperators, which dispatches it to that expander: + # 1. ``intersects_bin_size`` is rejected — it only configures the built-in + # binned transformer the override supersedes (rejected here, parallel to + # the iejoin rejection above, rather than silently dropped); + # 2. the DuckDB IEJoin short-circuit is skipped (the registry-deferral the + # IEJoin early-return used to preclude, #141); + # 3. the binned-join transformer is skipped. + target_overrides_intersects = REGISTRY.has_override(target, Intersects) + if target_overrides_intersects and intersects_bin_size is not None: + raise ValueError( + "intersects_bin_size has no effect when a target-specific " + f"(target={target.name!r}, Intersects) expander is registered; that " + "expander supersedes the built-in binned join transformer the bin " + "size configures. Pass one or the other, not both." + ) + tables_container = _build_tables(tables) with _reraise_as_value_error("Parse error", query=giql): ast = parse_one(giql, dialect=GIQLDialect) + # The column-to-column INTERSECTS *join* rewrites are capability-gated + # pre-pass transformers (epic #137, #141): the target's + # ``range_join_strategy`` selects the DuckDB IEJoin plan or the generic + # binned equi-join. They run on the raw parsed AST (before resolution, which + # rewrites the genomic column name) and consume a column-to-column INTERSECTS + # *join* so it never reaches the predicate expander; a literal-range or + # residual column-to-column INTERSECTS *predicate* survives to pass 3. 
+ # ``target_overrides_intersects`` (computed above) gates whether these + # built-in join paths run — see its definition for the override rationale. + # Falls back to the binned plan for unsupported shapes — see # IntersectsDuckDBIEJoinTransformer.transform_to_sql for the complete - # fallback set. - if uses_iejoin: + # fallback set. The IEJoin transformer emits a whole-query string, so when it + # produces output it must short-circuit the AST pipeline. This never skips an + # INTERSECTS that pass 3 would expand: ``_classify_extras`` forces the binned + # fallback (returning None here) for any query carrying a residual INTERSECTS + # alongside the join, so a query the IEJoin transformer accepts has no residual + # INTERSECTS left for the expander. (A residual CONTAINS/WITHIN/ANY beside an + # IEJoin is a pre-existing IEJoin limitation that errors identically on main.) + if uses_iejoin and not target_overrides_intersects: duckdb_transformer = IntersectsDuckDBIEJoinTransformer(tables_container) with _reraise_as_value_error("Transformation error"): duckdb_sql = duckdb_transformer.transform_to_sql(ast) if duckdb_sql is not None: - # WARNING: this early return emits the legacy IEJoin SQL directly and - # SKIPS the normalization pipeline below — pass 1 (resolution), pass 2 - # (canonicalization), and pass 3 (ExpandOperators, constructed ~40 - # lines down). The ExpandOperators registry is therefore NOT consulted - # on this path: a flagged operator on an IEJoin-eligible duckdb query - # is left un-expanded. This is benign today (the registry is empty and - # no operator opts in), but any DuckDB-pathed operator migration (#141) - # must either run expansion BEFORE this early return or have the IEJoin - # transformer defer to the registry. See the strict-xfail - # characterization test pinning this gap in tests/test_expander.py. 
return duckdb_sql - intersects_transformer = IntersectsBinnedJoinTransformer( - tables_container, - bin_size=intersects_bin_size, - ) merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) generator = BaseGIQLGenerator(tables=tables_container) with _reraise_as_value_error("Transformation error"): - ast = intersects_transformer.transform(ast) + # Reaching here with an iejoin target means the IEJoin transformer + # declined the query (returned None) and fell back to the binned plan, + # exactly as before. ``intersects_bin_size`` is rejected up front for + # iejoin targets, so the binned transformer always sees its default there. + if not target_overrides_intersects: + intersects_transformer = IntersectsBinnedJoinTransformer( + tables_container, + bin_size=intersects_bin_size, + ) + ast = intersects_transformer.transform(ast) ast = merge_transformer.transform(ast) ast = cluster_transformer.transform(ast) @@ -196,22 +235,21 @@ def transpile( with _reraise_as_value_error("Resolution error"): ast = resolve_operator_refs(ast, tables_container) - # Pass 2 of the normalization pipeline (epic #114): for each operator that - # opts into GIQL_CANONICALIZE, rewrite its non-canonical interval operands — - # synthesizing canonical __giql_canon_* wrapper CTEs — so downstream passes - # and emitters see canonical 0-based half-open coordinates. + # Pass 2 of the normalization pipeline (epic #114): synthesize canonical + # __giql_canon_* wrapper CTEs for non-canonical interval operands of operators + # that opt in via GIQL_CANONICALIZE; those operators are rewritten here, and + # operators that do not opt in are left untouched. with _reraise_as_value_error("Canonicalization error"): ast = canonicalize_coordinates(ast) - # Pass 3 of the normalization pipeline (epic #137): replace each opted-in - # GIQL operator node with the AST its registered expander produces for the - # active target. 
Each operator that opts in (GIQL_EXPAND) with a registered - # expander is rewritten here; any operator that is unflagged or has no - # registered expander falls through to its legacy *_sql emitter on the - # generator. - expand_operators = ExpandOperators(target, tables_container) + # Pass 3 of the normalization pipeline (epic #137): replace each GIQL operator + # node that opts in (GIQL_EXPAND) and resolves a registered expander with the + # AST that expander produces for the active target. Operators that are + # unflagged or resolve no expander are left untouched and the generator renders + # them via their legacy ``*_sql`` emitter as before. + expand_pass = ExpandOperators(target, tables_container) with _reraise_as_value_error("Expansion error"): - ast = expand_operators.transform(ast) + ast = expand_pass.transform(ast) with _reraise_as_value_error("Transpilation error"): sql = generator.generate(ast) diff --git a/tests/expanders/__init__.py b/tests/expanders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/expanders/test_intersects.py b/tests/expanders/test_intersects.py new file mode 100644 index 0000000..aff4e86 --- /dev/null +++ b/tests/expanders/test_intersects.py @@ -0,0 +1,442 @@ +"""Direct unit tests for the spatial / set predicate expanders (#141). + +These call ``expand_intersects`` / ``expand_contains`` / ``expand_within`` / +``expand_spatial_set`` directly with a hand-built :class:`ExpansionContext`, +characterizing each dispatch branch (column-to-column vs literal range; CONTAINS +point vs range; ANY/OR vs ALL/AND) and pinning the chosen error messages on +invalid input. They sit outside ``tests/test_expander.py`` so they do not touch +that file's shared, operator-agnostic fixture/infra region. 
+""" + +import pytest +from sqlglot import exp +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.expander import REGISTRY +from giql.expander import ExpansionContext +from giql.expanders.intersects import expand_contains +from giql.expanders.intersects import expand_intersects +from giql.expanders.intersects import expand_spatial_set +from giql.expanders.intersects import expand_within +from giql.expressions import Contains +from giql.expressions import Intersects +from giql.expressions import SpatialSetPredicate +from giql.expressions import Within +from giql.resolver import OperatorResolution +from giql.resolver import ResolvedColumn +from giql.table import Tables +from giql.targets import DataFusionTarget +from giql.targets import GenericTarget +from giql.transpile import transpile + +_LEFT = ResolvedColumn( + chrom='a."chrom"', start='a."start"', end='a."end"', strand=None, table=None +) +_RIGHT = ResolvedColumn( + chrom='b."chrom"', start='b."start"', end='b."end"', strand=None, table=None +) +_OPERATOR_TYPES = (Intersects, Contains, Within, SpatialSetPredicate) + + +def _context(query: str, columns: dict[str, ResolvedColumn]) -> tuple: + """Find the spatial operator in *query* and build a context with *columns*.""" + root = parse_one(query, dialect=GIQLDialect) + node = next(n for n in root.walk() if isinstance(n, _OPERATOR_TYPES)) + resolution = OperatorResolution( + operator=type(node).__name__, slots={}, columns=columns + ) + ctx = ExpansionContext(node, resolution, GenericTarget(), Tables()) + return node, ctx + + +def _sql(expression: exp.Expression) -> str: + """Serialize a built expression through the GIQL dialect.""" + return expression.sql(dialect=GIQLDialect) + + +class TestSpatialExpanders: + """Direct expansion of the spatial / set predicate expanders (#141).""" + + def test_expand_intersects_should_build_overlap_predicate_when_literal_range(self): + """Test that a literal-range INTERSECTS expands to the overlap 
predicate. + + Given: + An INTERSECTS node whose right operand is a literal range and a + context resolving only the left column. + When: + Expanding it. + Then: + It should build the overlap boolean (chrom = lit AND start < end2 AND + end > start2) with the right operand as numeric literals. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval INTERSECTS 'chr1:1000-2000'", + {"this": _LEFT}, + ) + + # Act + result = expand_intersects(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = \'chr1\' AND a."start" < 2000 AND a."end" > 1000)' + ) + + def test_expand_intersects_should_build_join_predicate_when_column_to_column(self): + """Test that a column-to-column INTERSECTS expands to a join predicate. + + Given: + An INTERSECTS node with a resolved right *column* (the dispatch keys on + ctx.resolution.column("expression")). + When: + Expanding it. + Then: + It should compare the two columns' endpoints rather than literals. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval", + {"this": _LEFT, "expression": _RIGHT}, + ) + + # Act + result = expand_intersects(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = b."chrom" AND a."start" < b."end" AND a."end" > b."start")' + ) + + def test_expand_contains_should_build_point_predicate_when_single_base(self): + """Test that a single-base CONTAINS expands to the point-containment form. + + Given: + A CONTAINS node whose literal range is a single base (end == start+1). + When: + Expanding it. + Then: + It should use the point form (start <= point AND end > point), not the + range form. 
+ """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval CONTAINS 'chr1:1000'", {"this": _LEFT} + ) + + # Act + result = expand_contains(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = \'chr1\' AND a."start" <= 1000 AND a."end" > 1000)' + ) + + def test_expand_contains_should_build_range_predicate_when_multi_base(self): + """Test that a multi-base CONTAINS expands to the range-containment form. + + Given: + A CONTAINS node whose literal range spans more than one base. + When: + Expanding it. + Then: + It should use the range form (start <= start2 AND end >= end2). + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval CONTAINS 'chr1:1000-2000'", + {"this": _LEFT}, + ) + + # Act + result = expand_contains(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = \'chr1\' AND a."start" <= 1000 AND a."end" >= 2000)' + ) + + def test_expand_contains_should_build_join_predicate_when_column_to_column(self): + """Test that a column-to-column CONTAINS expands to the containment join. + + Given: + A CONTAINS node with a resolved right *column* (the dispatch keys on + ctx.resolution.column("expression")). + When: + Expanding it. + Then: + It should build the left-contains-right predicate (start1 <= start2 + AND end1 >= end2) comparing the two columns' endpoints. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a JOIN b ON a.interval CONTAINS b.interval", + {"this": _LEFT, "expression": _RIGHT}, + ) + + # Act + result = expand_contains(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = b."chrom" AND a."start" <= b."start" AND a."end" >= b."end")' + ) + + def test_expand_within_should_build_containment_predicate_when_literal_range(self): + """Test that WITHIN expands to the left-within-right containment form. + + Given: + A WITHIN node with a literal range. + When: + Expanding it. + Then: + It should build start >= start2 AND end <= end2. 
+ """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval WITHIN 'chr1:1000-2000'", {"this": _LEFT} + ) + + # Act + result = expand_within(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = \'chr1\' AND a."start" >= 1000 AND a."end" <= 2000)' + ) + + def test_expand_within_should_build_join_predicate_when_column_to_column(self): + """Test that a column-to-column WITHIN expands to the within join predicate. + + Given: + A WITHIN node with a resolved right *column* (the dispatch keys on + ctx.resolution.column("expression")). + When: + Expanding it. + Then: + It should build the left-within-right predicate (start1 >= start2 AND + end1 <= end2) comparing the two columns' endpoints. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a JOIN b ON a.interval WITHIN b.interval", + {"this": _LEFT, "expression": _RIGHT}, + ) + + # Act + result = expand_within(node, ctx) + + # Assert + assert _sql(result) == ( + '(a."chrom" = b."chrom" AND a."start" >= b."start" AND a."end" <= b."end")' + ) + + def test_expand_spatial_set_should_or_combine_conditions_when_any(self): + """Test that an ANY set predicate OR-combines its per-range conditions. + + Given: + An INTERSECTS ANY node over two literal ranges. + When: + Expanding it. + Then: + The two per-range overlap predicates should be OR-combined inside one + outer paren. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval " + "INTERSECTS ANY ('chr1:1-100', 'chr1:200-300')", + {"this": _LEFT}, + ) + + # Act + result = expand_spatial_set(node, ctx) + + # Assert + assert _sql(result) == ( + '((a."chrom" = \'chr1\' AND a."start" < 100 AND a."end" > 1) OR ' + '(a."chrom" = \'chr1\' AND a."start" < 300 AND a."end" > 200))' + ) + + def test_expand_spatial_set_should_and_combine_conditions_when_all(self): + """Test that an ALL set predicate AND-combines its per-range conditions. + + Given: + An INTERSECTS ALL node over two literal ranges. + When: + Expanding it. 
+ Then: + The two per-range overlap predicates should be AND-combined inside one + outer paren. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval " + "INTERSECTS ALL ('chr1:1-100', 'chr1:200-300')", + {"this": _LEFT}, + ) + + # Act + result = expand_spatial_set(node, ctx) + + # Assert + assert _sql(result) == ( + '((a."chrom" = \'chr1\' AND a."start" < 100 AND a."end" > 1) AND ' + '(a."chrom" = \'chr1\' AND a."start" < 300 AND a."end" > 200))' + ) + + +@pytest.fixture +def isolated_registry(): + """Snapshot/restore the process REGISTRY so a test can register an override.""" + saved = REGISTRY.snapshot() + try: + yield REGISTRY + finally: + REGISTRY.restore(saved) + + +class TestBinnedTargetOverrideDeferral: + """A target-specific Intersects override defers the binned join rewrite (#141).""" + + def test_transpile_should_skip_binned_rewrite_when_target_override( + self, isolated_registry + ): + """Test that a (target, Intersects) override bypasses the binned transformer. + + Given: + A column-to-column INTERSECTS join on the generic binned path + (dialect='datafusion') with a (DataFusionTarget, Intersects) override + registered. + When: + Transpiling. + Then: + The override's sentinel reaches the SQL and no binned equi-join + artifact is emitted — the override takes over the join rewrite that the + built-in binned transformer would otherwise perform. + """ + # Arrange + isolated_registry.register( + DataFusionTarget(), + Intersects, + lambda n, c: exp.column("BINNED_OVERRIDE_SENTINEL"), + ) + query = ( + "SELECT a.start FROM peaks a " + "JOIN genes b ON a.interval INTERSECTS b.interval" + ) + + # Act + sql = transpile(query, tables=["peaks", "genes"], dialect="datafusion") + + # Assert + assert "BINNED_OVERRIDE_SENTINEL" in sql + assert "_bins" not in sql + + def test_transpile_should_reject_bin_size_when_target_override( + self, isolated_registry + ): + """Test that bin size is rejected under a binned-target Intersects override. 
+ + Given: + A (DataFusionTarget, Intersects) override registered. + When: + Transpiling with intersects_bin_size set (which only configures the + built-in binned transformer the override supersedes). + Then: + transpile() raises ValueError rather than silently dropping the bin + size, parallel to the iejoin rejection. + """ + # Arrange + isolated_registry.register( + DataFusionTarget(), + Intersects, + lambda n, c: exp.column("BINNED_OVERRIDE_SENTINEL"), + ) + query = ( + "SELECT a.start FROM peaks a " + "JOIN genes b ON a.interval INTERSECTS b.interval" + ) + + # Act & assert + with pytest.raises(ValueError, match=r"intersects_bin_size has no effect"): + transpile( + query, + tables=["peaks", "genes"], + dialect="datafusion", + intersects_bin_size=5000, + ) + + +class TestSpatialExpanderErrors: + """Characterization tests pinning the chosen error messages on invalid input.""" + + def test_expand_intersects_should_wrap_parse_error_when_invalid_range(self): + """Test that an unparseable literal range raises the wrapped diagnostic. + + Given: + An INTERSECTS node whose literal range string cannot be parsed. + When: + Expanding it. + Then: + It should raise ValueError with the historical "Could not parse + genomic range" wrapper, chained from the underlying parse error. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval INTERSECTS 'invalid'", {"this": _LEFT} + ) + + # Act & assert + with pytest.raises(ValueError, match=r"Could not parse genomic range") as exc: + expand_intersects(node, ctx) + assert exc.value.__cause__ is not None + + def test_expand_intersects_should_raise_invariant_when_left_unresolved(self): + """Test that a missing left-operand resolution raises the invariant error. + + Given: + An INTERSECTS node whose context resolved no "this" column (pass 1 did + not run). + When: + Expanding it. + Then: + It should raise ValueError naming the unresolved operand and pointing + at the ResolveOperatorRefs pass. 
+ """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval INTERSECTS 'chr1:1-100'", {} + ) + + # Act & assert + with pytest.raises( + ValueError, match=r"Spatial predicate operand 'this' was not resolved" + ): + expand_intersects(node, ctx) + + def test_expand_spatial_set_should_not_wrap_parse_error_when_invalid_range(self): + """Test that a set-predicate bad range surfaces the raw parser error. + + Given: + An INTERSECTS ANY node with one unparseable range. + When: + Expanding it. + Then: + The raw RangeParser ValueError propagates *unwrapped* — the set- + predicate path does NOT apply the "Could not parse genomic range" + wrapper the single-operand path does. This pins the current + (pre-existing) asymmetry so any future unification is a conscious + change, not an accident. + """ + # Arrange + node, ctx = _context( + "SELECT * FROM a WHERE interval INTERSECTS ANY ('bad', 'chr1:1-2')", + {"this": _LEFT}, + ) + + # Act & assert + with pytest.raises(ValueError, match=r"Invalid genomic range format") as exc: + expand_spatial_set(node, ctx) + assert "Could not parse genomic range" not in str(exc.value) diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index a7141c4..9c8f8b0 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -11,7 +11,7 @@ from sqlglot import exp from sqlglot import parse_one -import giql # noqa: F401 (ensures the built-in expanders are registered) +import giql.expanders # noqa: F401 (registers built-in expanders) from giql import Table from giql import transpile from giql.canonicalizer import canonicalize_coordinates @@ -28,17 +28,16 @@ def _generate_through_passes(sql: str, tables: Tables) -> str: """Parse, run normalization passes 1-3, then generate SQL. 
Coordinate canonicalization for operator operands moved out of the emitter and - into the CanonicalizeCoordinates pass (issue #123), and DISTANCE generation - itself moved onto the registry's AST-expansion pass (epic #137, issue #140). - Emitter-level tests that pin canonicalized / expanded output must therefore - run all three passes before generating, exactly as - :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a - bare parsed AST. The expansion pass only touches operators that opt in - (``GIQL_EXPAND``); operators still on the legacy emitter (NEAREST, the - spatial predicates) pass through untouched. This helper is used where the - full ``transpile`` pipeline would otherwise rewrite the node away (a - column-to-column ``INTERSECTS`` is turned into a binned equi-join before the - predicate emitter runs). + into the CanonicalizeCoordinates pass (issue #123), and DISTANCE (issue #140) + and the spatial / set predicates (issue #141) moved onto the registry's + ExpandOperators pass (epic #137). Emitter-level tests that pin canonicalized / + expanded output must therefore run all three passes before generating, exactly + as :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a + bare parsed AST. Operators still on the legacy emitter (NEAREST, DISJOIN) pass + through the expansion pass untouched. This helper is used where the full + ``transpile`` pipeline would otherwise rewrite the node away (a column-to-column + ``INTERSECTS`` is turned into a binned equi-join before the predicate expander + runs). 
""" ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) @@ -822,38 +821,33 @@ def test_giqldistance_canonicalizes_closed_ends_apart_from_gap_parity( ) assert output == expected - def test_error_handling_invalid_range(self): + def test_expand_intersects_should_raise_when_invalid_range(self): """ GIVEN invalid genomic range string in Intersects - WHEN intersects_sql is called + WHEN the INTERSECTS predicate is expanded THEN ValueError with descriptive message is raised. """ sql = "SELECT * FROM variants WHERE interval INTERSECTS 'invalid'" - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator() with pytest.raises(ValueError, match="Could not parse genomic range"): - generator.generate(ast) + _generate_through_passes(sql, Tables()) - def test_error_handling_unknown_operation(self): + def test_expand_intersects_should_raise_when_nonnumeric_range_bounds(self): """ - GIVEN unknown operation type in spatial operations - WHEN a spatial operation with unknown op_type is attempted - THEN ValueError is raised. + GIVEN an INTERSECTS range whose start/end bounds are non-numeric + WHEN the INTERSECTS predicate is expanded + THEN ValueError is raised from the range parse failure. - Note: This test verifies internal error handling by directly calling - a method with invalid input, which would only occur through code errors. + Note: 'chr:a-b' parses as a range shape but its bounds are not integers, + so the underlying RangeParser raises and the expander wraps it. (The + former "unknown operation" guard this exercised is now unreachable — + dispatch is closed over the three known op types — so this pins the + remaining reachable failure: a parse error on the literal range.) 
""" - # This is an indirect test - we verify the generator raises ValueError - # when given malformed range strings as that's how errors surface sql = "SELECT * FROM variants WHERE interval INTERSECTS 'chr:a-b'" - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator() with pytest.raises(ValueError): - generator.generate(ast) + _generate_through_passes(sql, Tables()) def test_select_sql_join_without_alias(self, tables_with_two_tables): """ diff --git a/tests/integration/datafusion/test_cross_target_oracle.py b/tests/integration/datafusion/test_cross_target_oracle.py index 6b717a1..3d4b864 100644 --- a/tests/integration/datafusion/test_cross_target_oracle.py +++ b/tests/integration/datafusion/test_cross_target_oracle.py @@ -3,9 +3,9 @@ These exercise the reusable oracle (``tests/integration/conftest.py``) over the operators that already emit identical generic SQL across Generic and DataFusion and run correctly on DuckDB: INTERSECTS (literal + column-to-column join), -CONTAINS, WITHIN, and standalone NEAREST. No operator has been migrated to the -expander registry yet (epic #137), so this lane locks in the verification path -every later migration (#140-#144) will consume. +CONTAINS, WITHIN, and standalone NEAREST. The spatial predicates have since been +migrated to the expander registry (#141, epic #137); this lane locks in the +verification path that migration and every later one (#142-#144) consume. 
For the non-join operators (DISTANCE, CONTAINS, WITHIN, ANY/ALL, CLUSTER, MERGE) the generic and datafusion targets emit byte-identical SQL and both run diff --git a/tests/test_expander.py b/tests/test_expander.py index 196571b..ab93732 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -1412,27 +1412,28 @@ def test_opted_in_restores_flag_after_exception(self): assert GIQLMerge.GIQL_EXPAND is False -class TestIEJoinEarlyReturnSkipsExpansion: - """Pin Finding 2: the duckdb IEJoin early return skips the ExpandOperators pass.""" - - @pytest.mark.xfail( - strict=True, - reason="#141: the duckdb IEJoin early return in transpile() emits before " - "ExpandOperators runs, so a flagged operator on an IEJoin-eligible query " - "is not expanded. Flips to pass when #141 runs expansion before the " - "early return (or defers the IEJoin transformer to the registry).", - ) - def test_iejoin_query_expands_flagged_operator(self, clean_registry): - """Test that an IEJoin-eligible duckdb query expands a flagged operator. +class TestIEJoinRegistryDeferral: + """The duckdb IEJoin path defers to a target-specific Intersects expander (#141). + + Resolves Finding 2: the IEJoin early return used to emit before the + ExpandOperators pass, so a flagged operator on an IEJoin-eligible query was + never expanded. Now a *target-specific* ``(DuckDBTarget, Intersects)`` + registry entry overrides the built-in join strategy entirely (the public + extension hook), while the default duckdb path — with no such override — + still emits the built-in IEJoin SQL. + """ + + def test_iejoin_query_expands_target_override_expander(self, clean_registry): + """Test that a target-specific Intersects override fires on an IEJoin query. Given: A column-to-column INTERSECTS join eligible for the duckdb IEJoin - path, with Intersects flagged GIQL_EXPAND and an expander registered. + path, with a (DuckDBTarget, Intersects) expander registered. When: Transpiling with dialect='duckdb'. 
Then: - The expander's sentinel should appear (currently it does NOT — the - IEJoin early return skips the pass; this xfail flips when #141 lands). + The override expander's sentinel should appear — the IEJoin path + defers to the registry rather than short-circuiting expansion. """ # Arrange clean_registry.register( @@ -1444,41 +1445,35 @@ def test_iejoin_query_expands_flagged_operator(self, clean_registry): ) # Act - with _opted_in(Intersects): - sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") + sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") # Assert assert "__giql_iejoin_sentinel" in sql + assert "SET VARIABLE __giql_iejoin_" not in sql - def test_iejoin_query_emits_legacy_sql_unchanged(self, clean_registry): - """Test that the legacy IEJoin SQL is emitted regardless of a flagged op. + def test_iejoin_query_emits_builtin_iejoin_without_override(self): + """Test that the default duckdb path emits the built-in IEJoin SQL. Given: - The same IEJoin-eligible duckdb query with Intersects flagged and an - expander registered. + The same IEJoin-eligible duckdb query and no target-specific + Intersects override registered (only the built-in generic expander). When: Transpiling with dialect='duckdb'. Then: - The legacy IEJoin SET VARIABLE SQL is emitted and the expander's - sentinel is absent (characterizing the current skip; the companion - xfail surfaces when #141 fixes it). + The built-in IEJoin SET VARIABLE SQL is emitted (the generic + predicate expander does not disable the join strategy). 
""" # Arrange - clean_registry.register( - DuckDBTarget(), Intersects, lambda n, c: exp.column("__giql_iejoin_sentinel") - ) query = ( "SELECT a.start FROM peaks a " "JOIN genes b ON a.interval INTERSECTS b.interval" ) # Act - with _opted_in(Intersects): - sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") + sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") # Assert assert "SET VARIABLE __giql_iejoin_" in sql - assert "__giql_iejoin_sentinel" not in sql class TestTranspileExpanderDispatch: