From d2de294ad16b6f9d11a52d5815ebd52ba3c73d5b Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Mon, 23 Mar 2026 19:40:07 +0000 Subject: [PATCH 01/28] Begin the rewrite with new Node class and implement the fixed value --- src/ldlite/database/_expansion/rewrite.py | 193 ++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 src/ldlite/database/_expansion/rewrite.py diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py new file mode 100644 index 0000000..2136921 --- /dev/null +++ b/src/ldlite/database/_expansion/rewrite.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Literal, TypeAlias + +import duckdb +import psycopg +from psycopg import sql + +Conn: TypeAlias = duckdb.DuckDBPyConnection | psycopg.Connection + + +@dataclass(frozen=True) +class NodeContext: + source: sql.Identifier + column: sql.Identifier + prefixes: frozenset[str] + prop: str | None + + +class Node: + def __init__(self, ctx: NodeContext): + self.ctx = ctx + + +class FixedValueNode(Node): + @property + @abstractmethod + def alias(self) -> str: ... + + @property + @abstractmethod + def stmt(self) -> sql.Composed: ... + + +class TypedNode(FixedValueNode): + def __init__( + self, + ctx: NodeContext, + json_type: Literal["string", "number", "boolean"], + other_json_type: Literal["string", "number", "boolean"], + ): + super().__init__(ctx) + + self.is_mixed = json_type = other_json_type + self.json_type: Literal["string", "number", "boolean", "null"] = ( + "string" if self.is_mixed else json_type + ) + self.is_uuid = False + self.is_datetime = False + self.is_float = False + self.is_bigint = False + + @property + def alias(self) -> str: + return "__".join(self.ctx.prefixes) + ( + ("__" + self.ctx.prop) if self.ctx.prop is not None else "" + ) + + @property + def str_extract(self) -> sql.Composed: + path = sql.SQL("->").join([sql.Literal(p) for p in self.ctx.prefixes]) + if self.ctx.prop is None: + str_extract = ( + sql.SQL("""TRIM(BOTH '"' FROM ({json_col}""").format(self.ctx.column) + + path + + sql.SQL(")::text)") + ) + else: + str_extract = path + sql.SQL("->>{prop}").format(self.ctx.prop) + + return sql.Composed( + [ + sql.SQL("NULLIF(NULLIF("), + str_extract, + sql.SQL(", ''), 'null')"), + ], + ) + + @property + def stmt(self) -> sql.Composed: + str_extract = self.str_extract + + if self.json_type == "number" and self.is_float: + type_extract = str_extract + sql.SQL("::numeric") + elif self.json_type == "number" and self.is_bigint: + type_extract = str_extract + sql.SQL("::bigint") + elif self.json_type == "number": + type_extract = str_extract + sql.SQL("::integer") + elif self.json_type == "boolean": + type_extract = str_extract + sql.SQL("::bool") + elif self.json_type == "string" and self.is_uuid: + type_extract = str_extract + sql.SQL("::uuid") + elif self.json_type == "string" and self.is_datetime: + type_extract = str_extract + sql.SQL("::timestamptz") + else: + type_extract = str_extract + + return type_extract + sql.SQL(" AS {alias}").format(alias=self.alias) + + def specify_type(self, conn: Conn) -> None: + if self.is_mixed or self.json_type == "boolean": + return + + cte = ( + sql.SQL(""" +SELECT string_values AS MATERIALIZED ( + SELECT """) + + self.str_extract + + sql.SQL(""" AS string_value + FROM {source} +)""").format(source=self.ctx.source) + ) + + if self.json_type == "string": + with conn.cursor() as cur: + specify = cte + sql.SQL(""" +SELECT + NOT EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + string_value NOT LIKE '________-____-____-____-____________' + ) AS is_uuid + ,NOT EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + ( + string_value NOT LIKE '____-__-__T__:__:__.___' OR + string_value NOT LIKE '____-__-__T__:__:__.___+__:__' + ) + ) AS is_uuid;""") + cur.execute(specify.as_string()) + if row := cur.fetchone(): + (self.is_uuid, self.is_datetime) = row + return + + with conn.cursor() as cur: + specify = cte + sql.SQL(""" +SELECT + EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + string_value::numeric % 1 <> 0 + ) AS is_float + ,NOT EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + string_value::numeric > 2147483647 + ) AS is_bigint;""") + cur.execute(specify.as_string()) + if row := cur.fetchone(): + (self.is_float, self.is_bigint) = row + else: + self.json_type = "string" + + +class OrdinalNode(FixedValueNode): + @property + def alias(self) -> str: + return "__".join(self.ctx.prefixes) + "__o" + + @property + def stmt(self) -> sql.Composed: + return sql.SQL('"ordinality" AS {alias}').format(alias=self.alias) + + +class ArrayIdentityNode(FixedValueNode): + @property + def alias(self) -> str: + return "__id" + + @property + def stmt(self) -> sql.Composed: + return sql.SQL("ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS __id").format( + alias=self.alias, + ) + + +class RecursiveNode(Node): ... + + +class ObjectNode(RecursiveNode): ... + + +class RootNode(ObjectNode): ... + + +class ArrayNode(RecursiveNode): ... From d8dcbda6d3a3f55d44bff9f2ef966cb306bf5b6c Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Mon, 23 Mar 2026 20:26:44 +0000 Subject: [PATCH 02/28] Implement getting the keys and types for objects --- src/ldlite/database/_expansion/rewrite.py | 111 +++++++++++++++++++--- 1 file changed, 98 insertions(+), 13 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 2136921..c306426 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -1,28 +1,45 @@ from __future__ import annotations from abc import abstractmethod +from collections import deque from dataclasses import dataclass -from typing import Literal, TypeAlias +from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, cast import duckdb import psycopg from psycopg import sql +if TYPE_CHECKING: + from collections.abc import Iterator + Conn: TypeAlias = duckdb.DuckDBPyConnection | psycopg.Connection +JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean"] -@dataclass(frozen=True) +@dataclass class NodeContext: source: sql.Identifier column: sql.Identifier - prefixes: frozenset[str] + prefixes: list[str] prop: str | None + def sub_prefix(self, prefix: str | None, prop: str | None) -> NodeContext: + return NodeContext( + self.source, + self.column, + [*self.prefixes, *([prefix] if prefix is not None else [])], + prop, + ) + class Node: def __init__(self, ctx: NodeContext): self.ctx = ctx + @property + def path(self) -> sql.Composed: + return sql.SQL("->").join([sql.Literal(p) for p in self.ctx.prefixes]) + class FixedValueNode(Node): @property @@ -38,15 +55,13 @@ class TypedNode(FixedValueNode): def __init__( self, ctx: NodeContext, - json_type: Literal["string", "number", "boolean"], - other_json_type: Literal["string", "number", "boolean"], + json_type: JsonType, + other_json_type: JsonType, ): super().__init__(ctx) self.is_mixed = json_type = other_json_type - self.json_type: Literal["string", "number", "boolean", "null"] = ( - "string" if self.is_mixed else json_type - ) + self.json_type: JsonType = "string" if self.is_mixed else json_type self.is_uuid = False self.is_datetime = False self.is_float = False @@ -60,15 +75,14 @@ def alias(self) -> str: @property def str_extract(self) -> sql.Composed: - path = sql.SQL("->").join([sql.Literal(p) for p in self.ctx.prefixes]) if self.ctx.prop is None: str_extract = ( sql.SQL("""TRIM(BOTH '"' FROM ({json_col}""").format(self.ctx.column) - + path + + self.path + sql.SQL(")::text)") ) else: - str_extract = path + sql.SQL("->>{prop}").format(self.ctx.prop) + str_extract = self.path + sql.SQL("->>{prop}").format(self.ctx.prop) return sql.Composed( [ @@ -181,10 +195,81 @@ def stmt(self) -> sql.Composed: ) -class RecursiveNode(Node): ... +TNode = TypeVar("TNode", bound="Node") +TRode = TypeVar("TRode", bound="RecursiveNode") + + +class RecursiveNode(Node): + def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): + super().__init__(ctx) + + self.parent = parent + self._children: list[Node] = [] + + def _direct(self, cls: type[TNode]) -> Iterator[TNode]: + yield from [n for n in self._children if isinstance(n, cls)] + + def direct(self, cls: type[TNode]) -> list[TNode]: + return list(self._direct(cls)) + + def _descendents(self, cls: type[TRode]) -> Iterator[TRode]: + to_visit = deque([self]) + while to_visit: + n = to_visit.pop() + if isinstance(n, cls): + yield n + + to_visit.extend(n.direct(RecursiveNode)) + def descendents(self, cls: type[TRode]) -> list[TRode]: + return list(self._descendents(cls)) -class ObjectNode(RecursiveNode): ... + +class ObjectNode(RecursiveNode): + def load_columns(self, conn: Conn) -> None: + with conn.cursor() as cur: + key_discovery = ( + sql.SQL(""" +SELECT + j.json_key + ,MIN(j.json_type) AS json_type + ,MAX(j.json_type) AS other_json_type +FROM +( + SELECT + json_key + ,jsonb_typeof(json_value) AS json_type + ,ord + FROM {table} t + CROSS JOIN LATERAL jsonb_each(t.{column}""").format( + table=self.ctx.source, + json_column=self.ctx.column, + ) + + self.path + + ( + sql.SQL("->{prop})").format(prop=self.ctx.prop) + if self.ctx.prop is not None + else sql.SQL(")") + ) + + sql.Composed(""" WITH ORDINALITY k(json_key, json_value, ord) +) j +WHERE json_type <> 'null' +GROUP BY json_key +ORDER BY MAX(j.ord), COUNT(*) + """) + ) + cur.execute(key_discovery.as_string()) + for row in cur.fetchall(): + (key, jt, ojt) = cast("tuple[str, JsonType, JsonType]", row) + if jt == "array" and ojt == "array": + anode = ArrayNode(self.ctx.sub_prefix(self.ctx.prop, key), self) + self._children.append(anode) + elif jt == "object" and ojt == "object": + onode = ObjectNode(self.ctx.sub_prefix(self.ctx.prop, key), self) + self._children.append(onode) + else: + tnode = TypedNode(self.ctx.sub_prefix(self.ctx.prop, key), jt, ojt) + self._children.append(tnode) class RootNode(ObjectNode): ... From 25ae3e55f0f99604883cf225267d806beb29a39b Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Mon, 23 Mar 2026 22:44:10 +0000 Subject: [PATCH 03/28] Remove the identity node type --- src/ldlite/database/_expansion/rewrite.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index c306426..1ff1d71 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -180,19 +180,7 @@ def alias(self) -> str: @property def stmt(self) -> sql.Composed: - return sql.SQL('"ordinality" AS {alias}').format(alias=self.alias) - - -class ArrayIdentityNode(FixedValueNode): - @property - def alias(self) -> str: - return "__id" - - @property - def stmt(self) -> sql.Composed: - return sql.SQL("ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS __id").format( - alias=self.alias, - ) + return sql.SQL('"ordinality"::smallint AS {alias}').format(alias=self.alias) TNode = TypeVar("TNode", bound="Node") From 1c460d6dcdf07bc72026c49396ab6127806dafb2 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 00:12:56 +0000 Subject: [PATCH 04/28] Implement staging the array as a temp file and getting the type --- src/ldlite/database/_expansion/rewrite.py | 155 +++++++++++++++++----- 1 file changed, 125 insertions(+), 30 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 1ff1d71..11a0a24 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -4,6 +4,7 @@ from collections import deque from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, cast +from uuid import uuid4 import duckdb import psycopg @@ -31,6 +32,17 @@ def sub_prefix(self, prefix: str | None, prop: str | None) -> NodeContext: prop, ) + def array_prefix( + self, + source: sql.Identifier, + ) -> NodeContext: + return NodeContext( + source, + sql.Identifier("jsonb"), + [], + None, + ) + class Node: def __init__(self, ctx: NodeContext): @@ -82,7 +94,9 @@ def str_extract(self) -> sql.Composed: + sql.SQL(")::text)") ) else: - str_extract = self.path + sql.SQL("->>{prop}").format(self.ctx.prop) + str_extract = self.path + sql.SQL("->>{prop}").format( + sql.Literal(self.ctx.prop), + ) return sql.Composed( [ @@ -111,7 +125,9 @@ def stmt(self) -> sql.Composed: else: type_extract = str_extract - return type_extract + sql.SQL(" AS {alias}").format(alias=self.alias) + return type_extract + sql.SQL(" AS {alias}").format( + alias=sql.Identifier(self.alias), + ) def specify_type(self, conn: Conn) -> None: if self.is_mixed or self.json_type == "boolean": @@ -146,6 +162,7 @@ def specify_type(self, conn: Conn) -> None: string_value NOT LIKE '____-__-__T__:__:__.___+__:__' ) ) AS is_uuid;""") + cur.execute(specify.as_string()) if row := cur.fetchone(): (self.is_uuid, self.is_datetime) = row @@ -166,6 +183,7 @@ def specify_type(self, conn: Conn) -> None: string_value IS NOT NULL AND string_value::numeric > 2147483647 ) AS is_bigint;""") + cur.execute(specify.as_string()) if row := cur.fetchone(): (self.is_float, self.is_bigint) = row @@ -180,7 +198,9 @@ def alias(self) -> str: @property def stmt(self) -> sql.Composed: - return sql.SQL('"ordinality"::smallint AS {alias}').format(alias=self.alias) + return sql.SQL("__o AS {alias}").format( + alias=sql.Identifier(self.alias), + ) TNode = TypeVar("TNode", bound="Node") @@ -194,6 +214,24 @@ def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): self.parent = parent self._children: list[Node] = [] + @property + def source_cte(self) -> sql.Composed: + return ( + sql.SQL(""" +WITH source AS MATERIALIZED ( + SELECT + t.{column}""").format(column=self.ctx.column) + + self.path + + ( + sql.SQL("->{prop}").format(prop=sql.Literal(self.ctx.prop)) + if self.ctx.prop is not None + else sql.SQL("") + ) + + sql.SQL(""" AS ld_value + FROM {source} t +)""").format(source=self.ctx.source) + ) + def _direct(self, cls: type[TNode]) -> Iterator[TNode]: yield from [n for n in self._children if isinstance(n, cls)] @@ -216,36 +254,24 @@ def descendents(self, cls: type[TRode]) -> list[TRode]: class ObjectNode(RecursiveNode): def load_columns(self, conn: Conn) -> None: with conn.cursor() as cur: - key_discovery = ( - sql.SQL(""" + key_discovery = self.source_cte + sql.Composed(""" SELECT - j.json_key - ,MIN(j.json_type) AS json_type - ,MAX(j.json_type) AS other_json_type -FROM -( + json_key + ,MIN(json_type) AS json_type + ,MAX(json_type) AS other_json_type +FROM ( SELECT - json_key - ,jsonb_typeof(json_value) AS json_type - ,ord - FROM {table} t - CROSS JOIN LATERAL jsonb_each(t.{column}""").format( - table=self.ctx.source, - json_column=self.ctx.column, - ) - + self.path - + ( - sql.SQL("->{prop})").format(prop=self.ctx.prop) - if self.ctx.prop is not None - else sql.SQL(")") - ) - + sql.Composed(""" WITH ORDINALITY k(json_key, json_value, ord) + j."key" AS json_key + ,jsonb_typeof(j."value") AS json_type + ,j.ord + FROM source t + CROSS JOIN LATERAL jsonb_each(t.ld_value) WITH ORDINALITY j("key", "value", ord) + WHERE jsonb_typeof(t.ld_value) = 'object' ) j WHERE json_type <> 'null' GROUP BY json_key -ORDER BY MAX(j.ord), COUNT(*) - """) - ) +ORDER BY MAX(j.ord), COUNT(*);""") + cur.execute(key_discovery.as_string()) for row in cur.fetchall(): (key, jt, ojt) = cast("tuple[str, JsonType, JsonType]", row) @@ -260,7 +286,76 @@ def load_columns(self, conn: Conn) -> None: self._children.append(tnode) -class RootNode(ObjectNode): ... +class RootNode(ObjectNode): + def __init__(self, source: sql.Identifier): + super().__init__( + NodeContext( + source, + sql.Identifier("jsonb"), + [], + None, + ), + None, + ) + + +class ArrayNode(RecursiveNode): + def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): + super().__init__(ctx, parent) + self.temp = sql.Identifier(str(uuid4()).split("-")[0]) + + def make_temp(self, conn: Conn) -> Node | None: + with conn.cursor() as cur: + expansion = ( + sql.SQL("CREATE TEMPORARY TABLE {temp} AS").format(temp=self.temp) + + self.source_cte + + sql.Composed(""" +SELECT + __id AS parent__id + ,(ROW_NUMBER() OVER (ORDER BY (SELECT NULL)))::integer AS __id + ,ord::smallint AS __o + ,jsonb + ,json_type +FROM ( + SELECT + t.__id + ,a."value" AS jsonb + ,jsonb_typeof(a."value") AS json_type + ,a.ord + FROM source t + CROSS JOIN LATERAL jsonb_each(t.ld_value) WITH ORDINALITY a("value", ord) + WHERE jsonb_typeof(t.ld_value) = 'array' +) a +WHERE json_type <> 'null' +""") + ) + cur.execute(expansion.as_string()) + + type_discovery = sql.SQL(""" +SELECT + MIN(json_type) AS json_type + ,MAX(json_type) AS other_json_type +FROM {temp}""").format(temp=self.temp) + + cur.execute(type_discovery.as_string()) + self._children.append(OrdinalNode(self.ctx.array_prefix(self.temp))) + if row := cur.fetchone(): + (jt, ojt) = cast("tuple[JsonType, JsonType]", row) + node: Node + if jt == "array" and ojt == "array": + node = ArrayNode(self.ctx.array_prefix(self.temp), self) + self._children.append(node) + elif jt == "object" and ojt == "object": + node = ObjectNode(self.ctx.array_prefix(self.temp), self) + self._children.append(node) + else: + node = TypedNode( + self.ctx.array_prefix(self.temp), + jt, + ojt, + ) + self._children.append(node) + return node -class ArrayNode(RecursiveNode): ... + return None From c1c2356e84e57d9f649176760e74375f4187be06 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 16:12:28 +0000 Subject: [PATCH 05/28] Implement the new simplified transform algorithm --- src/ldlite/database/_expansion/rewrite.py | 61 ++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 11a0a24..dcc5190 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -218,7 +218,7 @@ def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): def source_cte(self) -> sql.Composed: return ( sql.SQL(""" -WITH source AS MATERIALIZED ( +WITH source ( SELECT t.{column}""").format(column=self.ctx.column) + self.path @@ -250,6 +250,20 @@ def _descendents(self, cls: type[TRode]) -> Iterator[TRode]: def descendents(self, cls: type[TRode]) -> list[TRode]: return list(self._descendents(cls)) + def _typed_columns(self) -> Iterator[TypedNode]: + for n in self._descendents(RecursiveNode): + yield from n.direct(TypedNode) + + def typed_nodes(self) -> list[TypedNode]: + return list(self._typed_columns()) + + def remove(self, node: RecursiveNode) -> None: + self._children.remove(node) + + @property + def create_statement(self) -> sql.Composed: + return sql.Composed("") + class ObjectNode(RecursiveNode): def load_columns(self, conn: Conn) -> None: @@ -359,3 +373,48 @@ def make_temp(self, conn: Conn) -> Node | None: return node return None + + +def _non_srs_statements( + conn: Conn, + source_table: sql.Identifier, +) -> Iterator[sql.Composed]: + # Here be dragons! The nodes have inner state manipulations + # that violate the space/time continuum: + # * o.load_columns + # * a.make_temp + # * t.specify_type + # These all are expected to be called before generating the sql + # as they load/prepare database information. + # Because building up to the transformation statements takes a long time + # we're doing all that work up front to keep the time + # that a transaction is opened to a minimum (which is a leaky abstraction). + + root = RootNode(source_table) + onodes: deque[ObjectNode] = deque([root]) + while o := onodes.popleft(): + o.load_columns(conn) + onodes.extend(o.direct(ObjectNode)) + anodes = deque(o.direct(ArrayNode)) + while a := anodes.popleft(): + if n := a.make_temp(conn): + if isinstance(n, ObjectNode): + onodes.append(n) + if isinstance(n, ArrayNode): + anodes.append(n) + else: + cast("RecursiveNode", a.parent).remove(a) + + for t in root.typed_nodes(): + t.specify_type(conn) + + yield root.create_statement + for a in root.descendents(ArrayNode): + yield a.create_statement + + +def non_srs_statements( + conn: Conn, + source_table: sql.Identifier, +) -> list[sql.Composed]: + return list(_non_srs_statements(conn, source_table)) From 755d631704defe3e624768bcb36853afce360f06 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 16:26:24 +0000 Subject: [PATCH 06/28] Track scan progress --- src/ldlite/database/_expansion/rewrite.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index dcc5190..415744c 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Literal, TypeVar, cast from uuid import uuid4 import duckdb @@ -12,6 +12,9 @@ if TYPE_CHECKING: from collections.abc import Iterator + from typing import NoReturn, TypeAlias + + from tqdm import tqdm Conn: TypeAlias = duckdb.DuckDBPyConnection | psycopg.Connection JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean"] @@ -378,6 +381,7 @@ def make_temp(self, conn: Conn) -> Node | None: def _non_srs_statements( conn: Conn, source_table: sql.Identifier, + scan_progress: tqdm[NoReturn], ) -> Iterator[sql.Composed]: # Here be dragons! The nodes have inner state manipulations # that violate the space/time continuum: @@ -394,6 +398,9 @@ def _non_srs_statements( onodes: deque[ObjectNode] = deque([root]) while o := onodes.popleft(): o.load_columns(conn) + scan_progress.total += len(o.direct(Node)) + scan_progress.update(1) + onodes.extend(o.direct(ObjectNode)) anodes = deque(o.direct(ArrayNode)) while a := anodes.popleft(): @@ -402,11 +409,15 @@ def _non_srs_statements( onodes.append(n) if isinstance(n, ArrayNode): anodes.append(n) + scan_progress.total += 1 else: cast("RecursiveNode", a.parent).remove(a) + scan_progress.update(1) + for t in root.typed_nodes(): t.specify_type(conn) + scan_progress.update(1) yield root.create_statement for a in root.descendents(ArrayNode): @@ -416,5 +427,6 @@ def _non_srs_statements( def non_srs_statements( conn: Conn, source_table: sql.Identifier, + scan_progress: tqdm[NoReturn], ) -> list[sql.Composed]: - return list(_non_srs_statements(conn, source_table)) + return list(_non_srs_statements(conn, source_table, scan_progress)) From 0427c607eb65a4c83820be58f46b37bd8d461c48 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 16:35:35 +0000 Subject: [PATCH 07/28] Refactor sql statement generation location --- src/ldlite/database/_expansion/rewrite.py | 28 +++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 415744c..5041ff0 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import deque from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, TypeVar, cast @@ -263,10 +263,6 @@ def typed_nodes(self) -> list[TypedNode]: def remove(self, node: RecursiveNode) -> None: self._children.remove(node) - @property - def create_statement(self) -> sql.Composed: - return sql.Composed("") - class ObjectNode(RecursiveNode): def load_columns(self, conn: Conn) -> None: @@ -303,7 +299,13 @@ def load_columns(self, conn: Conn) -> None: self._children.append(tnode) -class RootNode(ObjectNode): +class StampableNode(ABC): + @property + @abstractmethod + def create_statement(self) -> sql.Composed: ... + + +class RootNode(ObjectNode, StampableNode): def __init__(self, source: sql.Identifier): super().__init__( NodeContext( @@ -315,8 +317,12 @@ def __init__(self, source: sql.Identifier): None, ) + @property + def create_statement(self) -> sql.Composed: + return sql.Composed("") + -class ArrayNode(RecursiveNode): +class ArrayNode(RecursiveNode, StampableNode): def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): super().__init__(ctx, parent) self.temp = sql.Identifier(str(uuid4()).split("-")[0]) @@ -377,6 +383,10 @@ def make_temp(self, conn: Conn) -> Node | None: return None + @property + def create_statement(self) -> sql.Composed: + return sql.Composed("") + def _non_srs_statements( conn: Conn, @@ -391,8 +401,8 @@ def _non_srs_statements( # These all are expected to be called before generating the sql # as they load/prepare database information. # Because building up to the transformation statements takes a long time - # we're doing all that work up front to keep the time - # that a transaction is opened to a minimum (which is a leaky abstraction). + # we're doing all that work up front to keep the time that + # a transaction is opened to a minimum (which is a leaky abstraction). root = RootNode(source_table) onodes: deque[ObjectNode] = deque([root]) From a016de6d23fa6cd58622a83a7eb77c025468ed10 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 17:02:39 +0000 Subject: [PATCH 08/28] Share more of the json object traversal code --- src/ldlite/database/_expansion/rewrite.py | 164 +++++++++++----------- 1 file changed, 81 insertions(+), 83 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 5041ff0..8edfeba 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -55,6 +55,39 @@ def __init__(self, ctx: NodeContext): def path(self) -> sql.Composed: return sql.SQL("->").join([sql.Literal(p) for p in self.ctx.prefixes]) + @property + def _json_source(self) -> sql.Composed: + return self.ctx.column + sql.SQL("->").join( + [sql.Literal(p) for p in self.ctx.prefixes], + ) + + @property + def json_value(self) -> sql.Composed: + if self.ctx.prop is None: + return self._json_source + return self._json_source + sql.Composed("->") + sql.Literal(self.ctx.prop) + + @property + def json_string(self) -> sql.Composed: + if self.ctx.prop is None: + str_extract = ( + sql.Composed("""TRIM(BOTH '"' FROM """) + + self._json_source + + sql.Composed(")::text)") + ) + else: + str_extract = ( + self._json_source + sql.Composed("->>") + sql.Literal(self.ctx.prop) + ) + + return sql.Composed( + [ + sql.SQL("NULLIF(NULLIF("), + str_extract, + sql.SQL(", ''), 'null')"), + ], + ) + class FixedValueNode(Node): @property @@ -88,61 +121,37 @@ def alias(self) -> str: ("__" + self.ctx.prop) if self.ctx.prop is not None else "" ) - @property - def str_extract(self) -> sql.Composed: - if self.ctx.prop is None: - str_extract = ( - sql.SQL("""TRIM(BOTH '"' FROM ({json_col}""").format(self.ctx.column) - + self.path - + sql.SQL(")::text)") - ) - else: - str_extract = self.path + sql.SQL("->>{prop}").format( - sql.Literal(self.ctx.prop), - ) - - return sql.Composed( - [ - sql.SQL("NULLIF(NULLIF("), - str_extract, - sql.SQL(", ''), 'null')"), - ], - ) - @property def stmt(self) -> sql.Composed: - str_extract = self.str_extract - if self.json_type == "number" and self.is_float: - type_extract = str_extract + sql.SQL("::numeric") + type_extract = self.json_string + sql.SQL("::numeric") elif self.json_type == "number" and self.is_bigint: - type_extract = str_extract + sql.SQL("::bigint") + type_extract = self.json_string + sql.SQL("::bigint") elif self.json_type == "number": - type_extract = str_extract + sql.SQL("::integer") + type_extract = self.json_string + sql.SQL("::integer") elif self.json_type == "boolean": - type_extract = str_extract + sql.SQL("::bool") + type_extract = self.json_string + sql.SQL("::bool") elif self.json_type == "string" and self.is_uuid: - type_extract = str_extract + sql.SQL("::uuid") + type_extract = self.json_string + sql.SQL("::uuid") elif self.json_type == "string" and self.is_datetime: - type_extract = str_extract + sql.SQL("::timestamptz") + type_extract = self.json_string + sql.SQL("::timestamptz") else: - type_extract = str_extract + type_extract = self.json_string - return type_extract + sql.SQL(" AS {alias}").format( - alias=sql.Identifier(self.alias), - ) + return type_extract + sql.Composed(" AS ") + sql.Identifier(self.alias) def specify_type(self, conn: Conn) -> None: if self.is_mixed or self.json_type == "boolean": return cte = ( - sql.SQL(""" + sql.Composed(""" SELECT string_values AS MATERIALIZED ( SELECT """) - + self.str_extract + + self.json_string + sql.SQL(""" AS string_value FROM {source} + WHERE string_value IS NOT NULL )""").format(source=self.ctx.source) ) @@ -152,18 +161,15 @@ def specify_type(self, conn: Conn) -> None: SELECT NOT EXISTS( SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - string_value NOT LIKE '________-____-____-____-____________' + WHERE string_value NOT LIKE '________-____-____-____-____________' ) AS is_uuid ,NOT EXISTS( SELECT 1 FROM string_values WHERE - string_value IS NOT NULL AND - ( - string_value NOT LIKE '____-__-__T__:__:__.___' OR - string_value NOT LIKE '____-__-__T__:__:__.___+__:__' - ) + ( + string_value NOT LIKE '____-__-__T__:__:__.___' AND + string_value NOT LIKE '____-__-__T__:__:__.___+__:__' + ) ) AS is_uuid;""") cur.execute(specify.as_string()) @@ -176,15 +182,11 @@ def specify_type(self, conn: Conn) -> None: SELECT EXISTS( SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - string_value::numeric % 1 <> 0 + WHERE string_value::numeric % 1 <> 0 ) AS is_float - ,NOT EXISTS( + ,EXISTS( SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - string_value::numeric > 2147483647 + WHERE string_value::numeric > 2147483647 ) AS is_bigint;""") cur.execute(specify.as_string()) @@ -217,24 +219,6 @@ def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): self.parent = parent self._children: list[Node] = [] - @property - def source_cte(self) -> sql.Composed: - return ( - sql.SQL(""" -WITH source ( - SELECT - t.{column}""").format(column=self.ctx.column) - + self.path - + ( - sql.SQL("->{prop}").format(prop=sql.Literal(self.ctx.prop)) - if self.ctx.prop is not None - else sql.SQL("") - ) - + sql.SQL(""" AS ld_value - FROM {source} t -)""").format(source=self.ctx.source) - ) - def _direct(self, cls: type[TNode]) -> Iterator[TNode]: yield from [n for n in self._children if isinstance(n, cls)] @@ -267,23 +251,32 @@ def remove(self, node: RecursiveNode) -> None: class ObjectNode(RecursiveNode): def load_columns(self, conn: Conn) -> None: with conn.cursor() as cur: - key_discovery = self.source_cte + sql.Composed(""" + key_discovery = ( + sql.Composed(""" SELECT json_key ,MIN(json_type) AS json_type ,MAX(json_type) AS other_json_type FROM ( SELECT - j."key" AS json_key - ,jsonb_typeof(j."value") AS json_type - ,j.ord - FROM source t - CROSS JOIN LATERAL jsonb_each(t.ld_value) WITH ORDINALITY j("key", "value", ord) - WHERE jsonb_typeof(t.ld_value) = 'object' -) j + k."key" AS json_key + ,jsonb_typeof(k."value") AS json_type + ,k.ord + FROM + ( + SELECT """) + + self.json_value + + sql.SQL(""" AS ld_value + FROM {source} + WHERE jsonb_typeof(ld_value) = 'object' + ) j + CROSS JOIN LATERAL jsonb_each(j.ld_value) WITH ORDINALITY k("key", "value", ord) +) key_discovery WHERE json_type <> 'null' GROUP BY json_key -ORDER BY MAX(j.ord), COUNT(*);""") +ORDER BY MAX(j.ord), COUNT(*); +""").format(source=self.ctx.source) + ) cur.execute(key_discovery.as_string()) for row in cur.fetchall(): @@ -331,7 +324,6 @@ def make_temp(self, conn: Conn) -> Node | None: with conn.cursor() as cur: expansion = ( sql.SQL("CREATE TEMPORARY TABLE {temp} AS").format(temp=self.temp) - + self.source_cte + sql.Composed(""" SELECT __id AS parent__id @@ -345,12 +337,18 @@ def make_temp(self, conn: Conn) -> Node | None: ,a."value" AS jsonb ,jsonb_typeof(a."value") AS json_type ,a.ord - FROM source t - CROSS JOIN LATERAL jsonb_each(t.ld_value) WITH ORDINALITY a("value", ord) - WHERE jsonb_typeof(t.ld_value) = 'array' -) a + FROM + ( + SELECT """) + + self.json_value + + sql.SQL(""" AS ld_value + FROM {source} + WHERE jsonb_typeof(ld_value) = 'array' + ) j + CROSS JOIN LATERAL jsonb_array_elements(j.ld_value) WITH ORDINALITY a("value", ord) +) expansion WHERE json_type <> 'null' -""") +""").format(source=self.ctx.source) ) cur.execute(expansion.as_string()) From dd40621481c1600fc9e9ccdfd2906db50bf06cc2 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 17:22:31 +0000 Subject: [PATCH 09/28] Skeleton out the create table sql statements --- src/ldlite/database/_expansion/rewrite.py | 56 +++++++++++++++++++---- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 8edfeba..1626b84 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -11,7 +11,7 @@ from psycopg import sql if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from typing import NoReturn, TypeAlias from tqdm import tqdm @@ -295,7 +295,8 @@ def load_columns(self, conn: Conn) -> None: class StampableNode(ABC): @property @abstractmethod - def create_statement(self) -> sql.Composed: ... + # The Callable construct is necessary until DuckDB implements CTAS RETURNING + def create_statement(self) -> Callable[[sql.SQL], sql.Composed]: ... class RootNode(ObjectNode, StampableNode): @@ -311,8 +312,26 @@ def __init__(self, source: sql.Identifier): ) @property - def create_statement(self) -> sql.Composed: - return sql.Composed("") + def create_statement(self) -> Callable[[sql.SQL], sql.Composed]: + def create_root_table(source_stmt: sql.SQL) -> sql.Composed: + return ( + sql.SQL(""" +CREATE OR REPLACE TABLE {output_table} AS +WITH root_source AS ( +""").format(output_table=sql.Identifier("")) + + source_stmt.format(source=self.ctx.source) + + sql.Composed( + """ +) +SELECT + """, + ) + + sql.SQL("\n ,").join([sql.Identifier("")]) + + sql.Composed(""" +FROM root_source""") + ) + + return create_root_table class ArrayNode(RecursiveNode, StampableNode): @@ -382,15 +401,36 @@ def make_temp(self, conn: Conn) -> Node | None: return None @property - def create_statement(self) -> sql.Composed: - return sql.Composed("") + def create_statement(self) -> Callable[[sql.SQL], sql.Composed]: + def create_array_table(source_stmt: sql.SQL) -> sql.Composed: + return ( + sql.SQL(""" +CREATE OR REPLACE TABLE {output_table} AS +WITH array_source AS ( +""").format(output_table=sql.Identifier("")) + + source_stmt.format(source=self.temp) + + sql.Composed( + """ +) +SELECT + """, + ) + + sql.SQL("\n ,").join([sql.Identifier("")]) + + sql.SQL(""" +FROM array_source a +JOIN {parent} p ON + a.p__id = p.__id; +""").format(parent=sql.Identifier("")) + ) + + return create_array_table def _non_srs_statements( conn: Conn, source_table: sql.Identifier, scan_progress: tqdm[NoReturn], -) -> Iterator[sql.Composed]: +) -> Iterator[Callable[[sql.SQL], sql.Composed]]: # Here be dragons! The nodes have inner state manipulations # that violate the space/time continuum: # * o.load_columns @@ -436,5 +476,5 @@ def non_srs_statements( conn: Conn, source_table: sql.Identifier, scan_progress: tqdm[NoReturn], -) -> list[sql.Composed]: +) -> list[Callable[[sql.SQL], sql.Composed]]: return list(_non_srs_statements(conn, source_table, scan_progress)) From 154644f0632ada429b0335172275f6ff8e731bb4 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 18:13:59 +0000 Subject: [PATCH 10/28] Build sql with output tables and columns --- src/ldlite/database/_expansion/rewrite.py | 82 +++++++++++++++++------ 1 file changed, 62 insertions(+), 20 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 1626b84..9f1c03a 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -237,12 +237,12 @@ def _descendents(self, cls: type[TRode]) -> Iterator[TRode]: def descendents(self, cls: type[TRode]) -> list[TRode]: return list(self._descendents(cls)) - def _typed_columns(self) -> Iterator[TypedNode]: + def _typed_nodes(self) -> Iterator[TypedNode]: for n in self._descendents(RecursiveNode): yield from n.direct(TypedNode) def typed_nodes(self) -> list[TypedNode]: - return list(self._typed_columns()) + return list(self._typed_nodes()) def remove(self, node: RecursiveNode) -> None: self._children.remove(node) @@ -296,11 +296,15 @@ class StampableNode(ABC): @property @abstractmethod # The Callable construct is necessary until DuckDB implements CTAS RETURNING - def create_statement(self) -> Callable[[sql.SQL], sql.Composed]: ... + def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: ... class RootNode(ObjectNode, StampableNode): - def __init__(self, source: sql.Identifier): + def __init__( + self, + source: sql.Identifier, + get_output_table: Callable[[str | None], tuple[str, sql.Identifier]], + ): super().__init__( NodeContext( source, @@ -310,28 +314,37 @@ def __init__(self, source: sql.Identifier): ), None, ) + self.get_output_table = get_output_table @property - def create_statement(self) -> Callable[[sql.SQL], sql.Composed]: + def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: + (output_table_name, output_table) = self.get_output_table(None) + def create_root_table(source_stmt: sql.SQL) -> sql.Composed: return ( sql.SQL(""" CREATE OR REPLACE TABLE {output_table} AS WITH root_source AS ( -""").format(output_table=sql.Identifier("")) - + source_stmt.format(source=self.ctx.source) +""").format(output_table=output_table) + + source_stmt.format(source_table=self.ctx.source) + sql.Composed( """ ) SELECT """, ) - + sql.SQL("\n ,").join([sql.Identifier("")]) + + sql.SQL("\n ,").join( + [ + t.stmt + for o in self.descendents(ObjectNode) + for t in o.direct(TypedNode) + ], + ) + sql.Composed(""" FROM root_source""") ) - return create_root_table + return (output_table_name, create_root_table) class ArrayNode(RecursiveNode, StampableNode): @@ -401,36 +414,64 @@ def make_temp(self, conn: Conn) -> Node | None: return None @property - def create_statement(self) -> Callable[[sql.SQL], sql.Composed]: + def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: + p: RecursiveNode | None = self + parents: list[RecursiveNode] = [] + while p is not None and (p := p.parent): + parents.append(p) + root = cast("RootNode", parents[-1]) + (output_table_name, output_table) = root.get_output_table( + "__" + "__".join(self.ctx.prefixes), + ) + (_, parent_table) = root.get_output_table( + "__" + "__".join(cast("Node", parents[0]).ctx.prefixes), + ) + def create_array_table(source_stmt: sql.SQL) -> sql.Composed: return ( sql.SQL(""" CREATE OR REPLACE TABLE {output_table} AS WITH array_source AS ( -""").format(output_table=sql.Identifier("")) - + source_stmt.format(source=self.temp) +""").format(output_table=output_table) + + source_stmt.format(source_table=self.temp) + sql.Composed( """ ) SELECT """, ) - + sql.SQL("\n ,").join([sql.Identifier("")]) + + sql.SQL("\n ,").join( + [ + sql.Composed("a.__id"), + *[ + t.alias + for p in reversed(parents) + for t in p.direct(TypedNode) + ], + *[t.stmt for t in self.direct(TypedNode)], + *[ + t.stmt + for o in self.descendents(ObjectNode) + for t in o.direct(TypedNode) + ], + ], + ) + sql.SQL(""" FROM array_source a -JOIN {parent} p ON +JOIN {parent_table} p ON a.p__id = p.__id; -""").format(parent=sql.Identifier("")) +""").format(parent_table=parent_table) ) - return create_array_table + return (output_table_name, create_array_table) def _non_srs_statements( conn: Conn, source_table: sql.Identifier, + output_table: Callable[[str | None], tuple[str, sql.Identifier]], scan_progress: tqdm[NoReturn], -) -> Iterator[Callable[[sql.SQL], sql.Composed]]: +) -> Iterator[tuple[str, Callable[[sql.SQL], sql.Composed]]]: # Here be dragons! The nodes have inner state manipulations # that violate the space/time continuum: # * o.load_columns @@ -442,7 +483,7 @@ def _non_srs_statements( # we're doing all that work up front to keep the time that # a transaction is opened to a minimum (which is a leaky abstraction). - root = RootNode(source_table) + root = RootNode(source_table, output_table) onodes: deque[ObjectNode] = deque([root]) while o := onodes.popleft(): o.load_columns(conn) @@ -475,6 +516,7 @@ def _non_srs_statements( def non_srs_statements( conn: Conn, source_table: sql.Identifier, + output_table: Callable[[str | None], tuple[str, sql.Identifier]], scan_progress: tqdm[NoReturn], -) -> list[Callable[[sql.SQL], sql.Composed]]: - return list(_non_srs_statements(conn, source_table, scan_progress)) +) -> list[tuple[str, Callable[[sql.SQL], sql.Composed]]]: + return list(_non_srs_statements(conn, source_table, output_table, scan_progress)) From 8e00886b7aa3eaa88993d0d6cbb3409fca29043e Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 18:18:48 +0000 Subject: [PATCH 11/28] Clear the indexing time when the transform finishes --- src/ldlite/database/_typed_database.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 6560c1b..30adb75 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -422,6 +422,7 @@ def _transform_complete( "final_rowcount" = $2 ,"transform_complete" = $3 ,"transform_time" = $4 + ,"index_time" = NULL ,"data_refresh_start" = "load_start" ,"data_refresh_end" = "download_complete" WHERE "table_prefix" = $1 From b3bed23e442d38daf023387e834e798224622c15 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 18:55:06 +0000 Subject: [PATCH 12/28] Refactor to use the rewrite expansion --- src/ldlite/database/_duckdb.py | 4 +- src/ldlite/database/_expansion/rewrite.py | 4 +- src/ldlite/database/_postgres.py | 7 +- src/ldlite/database/_prefix.py | 4 +- src/ldlite/database/_typed_database.py | 104 +++++++++------------- 5 files changed, 51 insertions(+), 72 deletions(-) diff --git a/src/ldlite/database/_duckdb.py b/src/ldlite/database/_duckdb.py index de0324d..d49c3fa 100644 --- a/src/ldlite/database/_duckdb.py +++ b/src/ldlite/database/_duckdb.py @@ -86,8 +86,8 @@ def ingest_records( return total - def source_table_cte_stmt(self, keep_source: bool) -> str: # noqa: ARG002 - return "WITH ld_source AS (SELECT * FROM {source_table})" + def source_stmt(self, keep_source: bool) -> sql.SQL: # noqa: ARG002 + return sql.SQL("SELECT * FROM {source_table}") # DuckDB has some strong opinions about cursors that are different than postgres diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 9f1c03a..99f2871 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -267,7 +267,7 @@ def load_columns(self, conn: Conn) -> None: SELECT """) + self.json_value + sql.SQL(""" AS ld_value - FROM {source} + FROM {source_table} WHERE jsonb_typeof(ld_value) = 'object' ) j CROSS JOIN LATERAL jsonb_each(j.ld_value) WITH ORDINALITY k("key", "value", ord) @@ -275,7 +275,7 @@ def load_columns(self, conn: Conn) -> None: WHERE json_type <> 'null' GROUP BY json_key ORDER BY MAX(j.ord), COUNT(*); -""").format(source=self.ctx.source) +""").format(source_table=self.ctx.source) ) cur.execute(key_discovery.as_string()) diff --git a/src/ldlite/database/_postgres.py b/src/ldlite/database/_postgres.py index 88cbc42..3edcece 100644 --- a/src/ldlite/database/_postgres.py +++ b/src/ldlite/database/_postgres.py @@ -103,7 +103,8 @@ def preprocess_source_table( ), ) - def source_table_cte_stmt(self, keep_source: bool) -> str: + def source_stmt(self, keep_source: bool) -> sql.SQL: if keep_source: - return "WITH ld_source AS (SELECT * FROM {source_table})" - return "WITH ld_source AS (DELETE FROM {source_table} RETURNING *)" + return sql.SQL("SELECT * FROM {source_table}") + + return sql.SQL("DELETE FROM {source_table} RETURNING *") diff --git a/src/ldlite/database/_prefix.py b/src/ldlite/database/_prefix.py index a93dd81..456083e 100644 --- a/src/ldlite/database/_prefix.py +++ b/src/ldlite/database/_prefix.py @@ -34,9 +34,9 @@ def raw_table(self) -> PrefixedTable: def _output_table(self) -> str: return self._prefix + "__t" - def output_table(self, prefix: str) -> PrefixedTable: + def output_table(self, prefix: str | None) -> PrefixedTable: return self._prefixed_table( - self._output_table + ("" if len(prefix) == 0 else "__" + prefix), + self._output_table + ("" if prefix is None else "__" + prefix), ) @property diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 30adb75..7de63e2 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -1,4 +1,3 @@ -# pyright: reportArgumentType=false from __future__ import annotations from abc import abstractmethod @@ -12,8 +11,7 @@ from tqdm import tqdm from . import Database -from ._expansion import expand_nonmarc -from ._expansion.context import ExpandContext +from ._expansion.rewrite import non_srs_statements from ._prefix import Prefix if TYPE_CHECKING: @@ -193,7 +191,7 @@ def preprocess_source_table( # https://github.com/duckdb/duckdb/issues/3417 # Only postgres supports it which is why we have an abstraction here @abstractmethod - def source_table_cte_stmt(self, keep_source: bool) -> str: ... + def source_stmt(self, keep_source: bool) -> sql.SQL: ... def expand_prefix( self, @@ -215,84 +213,64 @@ def expand_prefix( conn.commit() return [] - with closing(self._conn_factory(False)) as conn: - with conn.cursor() as cur: - cur.execute( - sql.SQL( - """ -CREATE TEMP TABLE {dest_table} AS -""" - + self.source_table_cte_stmt(keep_source=keep_raw) - + """ -SELECT * from ld_source; -""", - ) - .format( - dest_table=pfx.origin_table, - source_table=pfx.raw_table.id, - ) - .as_string(), - ) + transform_progress = ( + transform_progress + if transform_progress is not None + else tqdm(disable=True, total=0) + ) + transform_progress.total = 1 - tables_to_create = expand_nonmarc( - "jsonb", - ["__id"], - ExpandContext( - conn, - pfx.origin_table, - json_depth, - pfx.transform_table, - pfx.output_table, - self.preprocess_source_table, # type: ignore [arg-type] - self.source_table_cte_stmt, - scan_progress if scan_progress is not None else tqdm(disable=True), - transform_progress - if transform_progress is not None - else tqdm(disable=True), - ), + with closing(self._conn_factory(False)) as conn: + tables_to_create = non_srs_statements( + conn, + pfx.raw_table[1], + pfx.output_table, + scan_progress + if scan_progress is not None + else tqdm(disable=True, total=0), ) + transform_progress.total += len(tables_to_create) + 1 + transform_progress.update(1) with self._begin(conn): self._drop_extracted_tables(conn, pfx) + with conn.cursor() as cur: + for i, (_, table) in enumerate(tables_to_create): + create_table = table( + self.source_stmt(keep_source=(i == 0 and keep_raw)), + ) + cur.execute(create_table.as_string()) + transform_progress.update(1) + + # duckdb can't drop the raw table when creating the output table if not keep_raw: self._drop_raw_table(conn, pfx) - with conn.cursor() as cur: - for table in tables_to_create: - cur.execute(table[1].as_string()) + total = 0 with conn.cursor() as cur: - cur.execute( - sql.SQL( - """ -CREATE TABLE {catalog_table} ( - table_name text -) -""", - ) - .format(catalog_table=pfx.catalog_table.id) - .as_string(), - ) - total = 0 + create_catalog = sql.SQL( + """CREATE TABLE {catalog_table} (table_name text)""", + ).format(catalog_table=pfx.catalog_table.id) + cur.execute(create_catalog.as_string()) if len(tables_to_create) > 0: + insert_catalog = sql.SQL( + "INSERT INTO {catalog_table} VALUES ($1)", + ).format(catalog_table=pfx.catalog_table.id) cur.executemany( - sql.SQL("INSERT INTO {catalog_table} VALUES ($1)") - .format( - catalog_table=pfx.catalog_table.id, - ) - .as_string(), + insert_catalog.as_string(), [(pfx.catalog_table_row(t[0]),) for t in tables_to_create], ) - cur.execute( - sql.SQL("SELECT COUNT(*) FROM {table}") - .format(table=pfx.output_table("").id) - .as_string(), + count = sql.SQL("SELECT COUNT(*) FROM {table}").format( + table=pfx.output_table("").id, ) + cur.execute(count.as_string()) total = cast("tuple[int]", cur.fetchone())[0] + transform_progress.update(1) self._transform_complete(conn, pfx, total, transform_started) - return [t[0] for t in tables_to_create] + return [pfx.catalog_table_row(t[0]) for t in tables_to_create] def index_prefix(self, prefix: str, progress: tqdm[NoReturn] | None = None) -> None: pfx = Prefix(prefix) From 04b41c9e2fffc255bbc123a3e69c7e8044a7c34c Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 18:59:32 +0000 Subject: [PATCH 13/28] Use postgres sql format properly --- src/ldlite/database/_expansion/rewrite.py | 32 +++++++++-------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 99f2871..ca420fb 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -65,28 +65,22 @@ def _json_source(self) -> sql.Composed: def json_value(self) -> sql.Composed: if self.ctx.prop is None: return self._json_source - return self._json_source + sql.Composed("->") + sql.Literal(self.ctx.prop) + return self._json_source + sql.SQL("->") + sql.Literal(self.ctx.prop) @property def json_string(self) -> sql.Composed: if self.ctx.prop is None: str_extract = ( - sql.Composed("""TRIM(BOTH '"' FROM """) + sql.SQL("""TRIM(BOTH '"' FROM """) + self._json_source - + sql.Composed(")::text)") + + sql.SQL(")::text)") ) else: str_extract = ( - self._json_source + sql.Composed("->>") + sql.Literal(self.ctx.prop) + self._json_source + sql.SQL("->>") + sql.Literal(self.ctx.prop) ) - return sql.Composed( - [ - sql.SQL("NULLIF(NULLIF("), - str_extract, - sql.SQL(", ''), 'null')"), - ], - ) + return sql.SQL("NULLIF(NULLIF(") + str_extract + sql.SQL(", ''), 'null')") class FixedValueNode(Node): @@ -138,14 +132,14 @@ def stmt(self) -> sql.Composed: else: type_extract = self.json_string - return type_extract + sql.Composed(" AS ") + sql.Identifier(self.alias) + return type_extract + sql.SQL(" AS ") + sql.Identifier(self.alias) def specify_type(self, conn: Conn) -> None: if self.is_mixed or self.json_type == "boolean": return cte = ( - sql.Composed(""" + sql.SQL(""" SELECT string_values AS MATERIALIZED ( SELECT """) + self.json_string @@ -252,7 +246,7 @@ class ObjectNode(RecursiveNode): def load_columns(self, conn: Conn) -> None: with conn.cursor() as cur: key_discovery = ( - sql.Composed(""" + sql.SQL(""" SELECT json_key ,MIN(json_type) AS json_type @@ -327,7 +321,7 @@ def create_root_table(source_stmt: sql.SQL) -> sql.Composed: WITH root_source AS ( """).format(output_table=output_table) + source_stmt.format(source_table=self.ctx.source) - + sql.Composed( + + sql.SQL( """ ) SELECT @@ -340,7 +334,7 @@ def create_root_table(source_stmt: sql.SQL) -> sql.Composed: for t in o.direct(TypedNode) ], ) - + sql.Composed(""" + + sql.SQL(""" FROM root_source""") ) @@ -356,7 +350,7 @@ def make_temp(self, conn: Conn) -> Node | None: with conn.cursor() as cur: expansion = ( sql.SQL("CREATE TEMPORARY TABLE {temp} AS").format(temp=self.temp) - + sql.Composed(""" + + sql.SQL(""" SELECT __id AS parent__id ,(ROW_NUMBER() OVER (ORDER BY (SELECT NULL)))::integer AS __id @@ -434,7 +428,7 @@ def create_array_table(source_stmt: sql.SQL) -> sql.Composed: WITH array_source AS ( """).format(output_table=output_table) + source_stmt.format(source_table=self.temp) - + sql.Composed( + + sql.SQL( """ ) SELECT @@ -442,7 +436,7 @@ def create_array_table(source_stmt: sql.SQL) -> sql.Composed: ) + sql.SQL("\n ,").join( [ - sql.Composed("a.__id"), + sql.Identifier("a", "__id"), *[ t.alias for p in reversed(parents) From 3cf5217882b5d1dfb1b4d7164dda041cef7c5ac6 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 19:14:07 +0000 Subject: [PATCH 14/28] Add a duckdb jsonb_each shim --- src/ldlite/database/_duckdb.py | 8 ++++---- tests/test_json_operators.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ldlite/database/_duckdb.py b/src/ldlite/database/_duckdb.py index d49c3fa..1b5334c 100644 --- a/src/ldlite/database/_duckdb.py +++ b/src/ldlite/database/_duckdb.py @@ -41,13 +41,13 @@ def __init__(self, db: duckdb.DuckDBPyConnection) -> None: END ; -CREATE OR REPLACE FUNCTION jsonb_object_keys(j) AS TABLE - SELECT je.key as ld_key, id as "ordinality" FROM json_each(j) je ORDER BY je.id -; - CREATE OR REPLACE FUNCTION jsonb_array_elements(j) AS TABLE ( SELECT value as ld_value, rowid + 1 AS "ordinality" FROM main.json_each(j) ); + +CREATE OR REPLACE FUNCTION jsonb_each(j) AS TABLE ( + SELECT key, value, rowid AS "ordinality" FROM main.json_each(j) +) """, ) diff --git a/tests/test_json_operators.py b/tests/test_json_operators.py index ac43355..75a2672 100644 --- a/tests/test_json_operators.py +++ b/tests/test_json_operators.py @@ -27,7 +27,7 @@ class JsonTC: assertion_params: tuple[Any, ...] -def case_jsonb_object_keys() -> JsonTC: +def case_jsonb_each() -> JsonTC: return JsonTC( """ {assertion} @@ -35,7 +35,7 @@ def case_jsonb_object_keys() -> JsonTC: FROM (SELECT 'k1' jkey UNION SELECT 'k2' jkey) as e FULL OUTER JOIN ( SELECT k.ld_key as jkey - FROM j, jsonb_object_keys(j.jc->'obj') k(ld_key) + FROM j, jsonb_each(j.jc->'obj') k(ld_key) ) as a USING (jkey) WHERE e.jkey IS NULL or a.jkey IS NULL) as q;""", From 60d158e50ff5ae39e6a144c277472968266c8105 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 19:25:26 +0000 Subject: [PATCH 15/28] Various syntax fixes --- src/ldlite/database/_expansion/rewrite.py | 63 ++++++++++++----------- tests/test_expansion.py | 7 ++- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index ca420fb..ffb6d58 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -102,7 +102,7 @@ def __init__( ): super().__init__(ctx) - self.is_mixed = json_type = other_json_type + self.is_mixed = json_type != other_json_type self.json_type: JsonType = "string" if self.is_mixed else json_type self.is_uuid = False self.is_datetime = False @@ -140,12 +140,11 @@ def specify_type(self, conn: Conn) -> None: cte = ( sql.SQL(""" -SELECT string_values AS MATERIALIZED ( +WITH string_values AS MATERIALIZED ( SELECT """) + self.json_string + sql.SQL(""" AS string_value FROM {source} - WHERE string_value IS NOT NULL )""").format(source=self.ctx.source) ) @@ -155,15 +154,18 @@ def specify_type(self, conn: Conn) -> None: SELECT NOT EXISTS( SELECT 1 FROM string_values - WHERE string_value NOT LIKE '________-____-____-____-____________' + WHERE + string_value IS NOT NULL AND + string_value NOT LIKE '________-____-____-____-____________' ) AS is_uuid ,NOT EXISTS( SELECT 1 FROM string_values WHERE - ( - string_value NOT LIKE '____-__-__T__:__:__.___' AND - string_value NOT LIKE '____-__-__T__:__:__.___+__:__' - ) + string_value IS NOT NULL AND + ( + string_value NOT LIKE '____-__-__T__:__:__.___' AND + string_value NOT LIKE '____-__-__T__:__:__.___+__:__' + ) ) AS is_uuid;""") cur.execute(specify.as_string()) @@ -176,11 +178,15 @@ def specify_type(self, conn: Conn) -> None: SELECT EXISTS( SELECT 1 FROM string_values - WHERE string_value::numeric % 1 <> 0 + WHERE + string_value IS NOT NULL AND + string_value::numeric % 1 <> 0 ) AS is_float ,EXISTS( SELECT 1 FROM string_values - WHERE string_value::numeric > 2147483647 + WHERE + string_value IS NOT NULL AND + string_value::numeric > 2147483647 ) AS is_bigint;""") cur.execute(specify.as_string()) @@ -251,7 +257,8 @@ def load_columns(self, conn: Conn) -> None: json_key ,MIN(json_type) AS json_type ,MAX(json_type) AS other_json_type -FROM ( +FROM +( SELECT k."key" AS json_key ,jsonb_typeof(k."value") AS json_type @@ -260,15 +267,14 @@ def load_columns(self, conn: Conn) -> None: ( SELECT """) + self.json_value - + sql.SQL(""" AS ld_value - FROM {source_table} - WHERE jsonb_typeof(ld_value) = 'object' + + sql.SQL(""" AS ld_value FROM {source_table} ) j CROSS JOIN LATERAL jsonb_each(j.ld_value) WITH ORDINALITY k("key", "value", ord) + WHERE jsonb_typeof(j.ld_value) = 'object' ) key_discovery WHERE json_type <> 'null' GROUP BY json_key -ORDER BY MAX(j.ord), COUNT(*); +ORDER BY MAX(ord), COUNT(*); """).format(source_table=self.ctx.source) ) @@ -317,13 +323,11 @@ def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: def create_root_table(source_stmt: sql.SQL) -> sql.Composed: return ( sql.SQL(""" -CREATE OR REPLACE TABLE {output_table} AS -WITH root_source AS ( -""").format(output_table=output_table) +CREATE TABLE {output_table} AS +WITH root_source AS (""").format(output_table=output_table) + source_stmt.format(source_table=self.ctx.source) + sql.SQL( - """ -) + """) SELECT """, ) @@ -367,11 +371,10 @@ def make_temp(self, conn: Conn) -> Node | None: ( SELECT """) + self.json_value - + sql.SQL(""" AS ld_value - FROM {source} - WHERE jsonb_typeof(ld_value) = 'array' + + sql.SQL(""" AS ld_value FROM {source} ) j CROSS JOIN LATERAL jsonb_array_elements(j.ld_value) WITH ORDINALITY a("value", ord) + WHERE jsonb_typeof(j.ld_value) = 'array' ) expansion WHERE json_type <> 'null' """).format(source=self.ctx.source) @@ -424,13 +427,11 @@ def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: def create_array_table(source_stmt: sql.SQL) -> sql.Composed: return ( sql.SQL(""" -CREATE OR REPLACE TABLE {output_table} AS -WITH array_source AS ( -""").format(output_table=output_table) +CREATE TABLE {output_table} AS +WITH array_source AS (""").format(output_table=output_table) + source_stmt.format(source_table=self.temp) + sql.SQL( - """ -) + """) SELECT """, ) @@ -479,14 +480,16 @@ def _non_srs_statements( root = RootNode(source_table, output_table) onodes: deque[ObjectNode] = deque([root]) - while o := onodes.popleft(): + while onodes: + o = onodes.popleft() o.load_columns(conn) scan_progress.total += len(o.direct(Node)) scan_progress.update(1) onodes.extend(o.direct(ObjectNode)) anodes = deque(o.direct(ArrayNode)) - while a := anodes.popleft(): + while anodes: + a = anodes.popleft() if n := a.make_temp(conn): if isinstance(n, ObjectNode): onodes.append(n) diff --git a/tests/test_expansion.py b/tests/test_expansion.py index 094dfa4..9efed25 100644 --- a/tests/test_expansion.py +++ b/tests/test_expansion.py @@ -52,6 +52,7 @@ def case_typed_columns() -> ExpansionTC: "timestamptz": "2028-01-23T00:00:00.000+00:00", "integer": 1, "numeric": 1.2, + "bigint": 1774374169585, "text": "value", "boolean": false, "uuid": "88888888-8888-1888-8888-888888888888" @@ -60,9 +61,10 @@ def case_typed_columns() -> ExpansionTC: b""" { "id": "id2", - "timestamptz": "2025-06-20T17:37:58.675+00:00", + "timestamptz": "2025-06-20T17:37:58.675", "integer": 2, - "numeric": 2.3, + "numeric": 2, + "bigint": 2, "text": "00000000-0000-1000-A000-000000000000", "boolean": false, "uuid": "11111111-1111-1111-8111-111111111111" @@ -82,6 +84,7 @@ def case_typed_columns() -> ExpansionTC: for a in [ ("integer", "integer", "INTEGER"), ("numeric", "numeric", "DECIMAL(18,3)"), + ("bigint", "bigint", "BIGINT"), ("text", "text", "VARCHAR"), ("uuid", "uuid", "UUID"), ("boolean", "boolean", "BOOLEAN"), From df0f2f831ab0ffa5a86641e180e021ecff996e19 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 19:31:48 +0000 Subject: [PATCH 16/28] Fix aliasing for basic datatypes column test --- src/ldlite/database/_expansion/rewrite.py | 5 ++++- src/ldlite/database/_typed_database.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index ffb6d58..6630774 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -111,8 +111,11 @@ def __init__( @property def alias(self) -> str: + if len(self.ctx.prefixes) == 0: + return self.ctx.prop if self.ctx.prop is not None else "" + return "__".join(self.ctx.prefixes) + ( - ("__" + self.ctx.prop) if self.ctx.prop is not None else "" + ("_" + self.ctx.prop) if self.ctx.prop is not None else "" ) @property diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 7de63e2..9f412e1 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -262,7 +262,7 @@ def expand_prefix( ) count = sql.SQL("SELECT COUNT(*) FROM {table}").format( - table=pfx.output_table("").id, + table=pfx.output_table(None).id, ) cur.execute(count.as_string()) total = cast("tuple[int]", cur.fetchone())[0] From 72449e9308896a5069644101e99e9f1b2b669ec5 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 19:37:29 +0000 Subject: [PATCH 17/28] Fix snake case naming --- src/ldlite/database/_expansion/rewrite.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 6630774..7affaa5 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -27,6 +27,18 @@ class NodeContext: prefixes: list[str] prop: str | None + @property + def snake(self) -> str | None: + if self.prop is None: + return None + + snake = "".join("_" + c.lower() if c.isupper() else c for c in self.prop) + # there's also sorts of weird edge cases here that don't come up in practice + if (naked := self.prop.lstrip("_")) and len(naked) > 0 and naked[0].isupper(): + snake = snake.removeprefix("_") + + return snake + def sub_prefix(self, prefix: str | None, prop: str | None) -> NodeContext: return NodeContext( self.source, @@ -112,10 +124,10 @@ def __init__( @property def alias(self) -> str: if len(self.ctx.prefixes) == 0: - return self.ctx.prop if self.ctx.prop is not None else "" + return self.ctx.snake if self.ctx.snake is not None else "" return "__".join(self.ctx.prefixes) + ( - ("_" + self.ctx.prop) if self.ctx.prop is not None else "" + ("_" + self.ctx.snake) if self.ctx.snake is not None else "" ) @property From bc7c490c261770f92fbc63cbc3121857616e195c Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Tue, 24 Mar 2026 19:57:47 +0000 Subject: [PATCH 18/28] Fix basic object expansion --- src/ldlite/database/_expansion/rewrite.py | 51 +++++++++++++++-------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 7affaa5..0bd3cd1 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -27,18 +27,25 @@ class NodeContext: prefixes: list[str] prop: str | None - @property - def snake(self) -> str | None: - if self.prop is None: - return None - - snake = "".join("_" + c.lower() if c.isupper() else c for c in self.prop) + @staticmethod + def _snake(not_snake: str) -> str: + snake = "".join("_" + c.lower() if c.isupper() else c for c in not_snake) # there's also sorts of weird edge cases here that don't come up in practice - if (naked := self.prop.lstrip("_")) and len(naked) > 0 and naked[0].isupper(): + if (naked := not_snake.lstrip("_")) and len(naked) > 0 and naked[0].isupper(): snake = snake.removeprefix("_") return snake + @property + def snake_prop(self) -> str | None: + if self.prop is None: + return None + return self._snake(self.prop) + + @property + def snake_prefixes(self) -> list[str]: + return [self._snake(p) for p in self.prefixes] + def sub_prefix(self, prefix: str | None, prop: str | None) -> NodeContext: return NodeContext( self.source, @@ -68,13 +75,20 @@ def path(self) -> sql.Composed: return sql.SQL("->").join([sql.Literal(p) for p in self.ctx.prefixes]) @property - def _json_source(self) -> sql.Composed: - return self.ctx.column + sql.SQL("->").join( - [sql.Literal(p) for p in self.ctx.prefixes], + def _json_source(self) -> sql.Composable: + if len(self.ctx.prefixes) == 0: + return self.ctx.column + + return ( + self.ctx.column + + sql.SQL("->") + + sql.SQL("->").join( + [sql.Literal(p) for p in self.ctx.prefixes], + ) ) @property - def json_value(self) -> sql.Composed: + def json_value(self) -> sql.Composable: if self.ctx.prop is None: return self._json_source return self._json_source + sql.SQL("->") + sql.Literal(self.ctx.prop) @@ -124,10 +138,10 @@ def __init__( @property def alias(self) -> str: if len(self.ctx.prefixes) == 0: - return self.ctx.snake if self.ctx.snake is not None else "" + return self.ctx.snake_prop if self.ctx.snake_prop is not None else "" - return "__".join(self.ctx.prefixes) + ( - ("_" + self.ctx.snake) if self.ctx.snake is not None else "" + return "__".join(self.ctx.snake_prefixes) + ( + ("__" + self.ctx.snake_prop) if self.ctx.snake_prop is not None else "" ) @property @@ -348,9 +362,12 @@ def create_root_table(source_stmt: sql.SQL) -> sql.Composed: ) + sql.SQL("\n ,").join( [ - t.stmt - for o in self.descendents(ObjectNode) - for t in o.direct(TypedNode) + sql.Identifier("__id"), + *[ + t.stmt + for o in self.descendents(ObjectNode) + for t in o.direct(TypedNode) + ], ], ) + sql.SQL(""" From 5d55799fd127c1cdfe59347e4274fba35b4cfee8 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Wed, 25 Mar 2026 14:09:30 +0000 Subject: [PATCH 19/28] WIP: Making array expansion work --- src/ldlite/database/_expansion/rewrite.py | 88 +++++++++++++---------- 1 file changed, 51 insertions(+), 37 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 0bd3cd1..5e6cf8c 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from collections import deque -from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, TypeVar, cast from uuid import uuid4 @@ -20,12 +19,19 @@ JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean"] -@dataclass class NodeContext: - source: sql.Identifier - column: sql.Identifier - prefixes: list[str] - prop: str | None + def __init__( + self, + source: sql.Identifier, + column: sql.Identifier, + prefixes: list[str], + prop: str | None, + ): + self.source = source + self.column = column + self.prefixes = prefixes + self.path = prefixes + self.prop = prop @staticmethod def _snake(not_snake: str) -> str: @@ -57,13 +63,16 @@ def sub_prefix(self, prefix: str | None, prop: str | None) -> NodeContext: def array_prefix( self, source: sql.Identifier, + prefix: str | None, ) -> NodeContext: - return NodeContext( + context = NodeContext( source, sql.Identifier("jsonb"), - [], + [*self.prefixes, *([prefix] if prefix is not None else [])], None, ) + context.path = cast("list[str]", []) + return context class Node: @@ -72,20 +81,14 @@ def __init__(self, ctx: NodeContext): @property def path(self) -> sql.Composed: - return sql.SQL("->").join([sql.Literal(p) for p in self.ctx.prefixes]) + return sql.SQL("->").join([sql.Literal(p) for p in self.ctx.path]) @property def _json_source(self) -> sql.Composable: - if len(self.ctx.prefixes) == 0: + if len(self.ctx.path) == 0: return self.ctx.column - return ( - self.ctx.column - + sql.SQL("->") - + sql.SQL("->").join( - [sql.Literal(p) for p in self.ctx.prefixes], - ) - ) + return self.ctx.column + sql.SQL("->") + self.path @property def json_value(self) -> sql.Composable: @@ -99,7 +102,7 @@ def json_string(self) -> sql.Composed: str_extract = ( sql.SQL("""TRIM(BOTH '"' FROM """) + self._json_source - + sql.SQL(")::text)") + + sql.SQL("::text)") ) else: str_extract = ( @@ -111,12 +114,12 @@ def json_string(self) -> sql.Composed: class FixedValueNode(Node): @property - @abstractmethod - def alias(self) -> str: ... + def alias(self) -> str: + return "" @property - @abstractmethod - def stmt(self) -> sql.Composed: ... + def stmt(self) -> sql.Composed: + return sql.Composed("") class TypedNode(FixedValueNode): @@ -228,11 +231,11 @@ def specify_type(self, conn: Conn) -> None: class OrdinalNode(FixedValueNode): @property def alias(self) -> str: - return "__".join(self.ctx.prefixes) + "__o" + return "__".join(self.ctx.snake_prefixes) + "__o" @property def stmt(self) -> sql.Composed: - return sql.SQL("__o AS {alias}").format( + return sql.Identifier("a", "__o") + sql.SQL(" AS {alias}").format( alias=sql.Identifier(self.alias), ) @@ -388,14 +391,14 @@ def make_temp(self, conn: Conn) -> Node | None: sql.SQL("CREATE TEMPORARY TABLE {temp} AS").format(temp=self.temp) + sql.SQL(""" SELECT - __id AS parent__id + __id AS p__id ,(ROW_NUMBER() OVER (ORDER BY (SELECT NULL)))::integer AS __id ,ord::smallint AS __o ,jsonb ,json_type FROM ( SELECT - t.__id + j.__id ,a."value" AS jsonb ,jsonb_typeof(a."value") AS json_type ,a.ord @@ -403,7 +406,7 @@ def make_temp(self, conn: Conn) -> Node | None: ( SELECT """) + self.json_value - + sql.SQL(""" AS ld_value FROM {source} + + sql.SQL(""" AS ld_value, __id FROM {source} ) j CROSS JOIN LATERAL jsonb_array_elements(j.ld_value) WITH ORDINALITY a("value", ord) WHERE jsonb_typeof(j.ld_value) = 'array' @@ -420,19 +423,27 @@ def make_temp(self, conn: Conn) -> Node | None: FROM {temp}""").format(temp=self.temp) cur.execute(type_discovery.as_string()) - self._children.append(OrdinalNode(self.ctx.array_prefix(self.temp))) + self._children.append( + OrdinalNode(self.ctx.array_prefix(self.temp, self.ctx.prop)), + ) if row := cur.fetchone(): (jt, ojt) = cast("tuple[JsonType, JsonType]", row) node: Node if jt == "array" and ojt == "array": - node = ArrayNode(self.ctx.array_prefix(self.temp), self) + node = ArrayNode( + self.ctx.array_prefix(self.temp, self.ctx.prop), + self, + ) self._children.append(node) elif jt == "object" and ojt == "object": - node = ObjectNode(self.ctx.array_prefix(self.temp), self) + node = ObjectNode( + self.ctx.array_prefix(self.temp, self.ctx.prop), + self, + ) self._children.append(node) else: node = TypedNode( - self.ctx.array_prefix(self.temp), + self.ctx.array_prefix(self.temp, self.ctx.prop), jt, ojt, ) @@ -450,11 +461,14 @@ def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: parents.append(p) root = cast("RootNode", parents[-1]) (output_table_name, output_table) = root.get_output_table( - "__" + "__".join(self.ctx.prefixes), - ) - (_, parent_table) = root.get_output_table( - "__" + "__".join(cast("Node", parents[0]).ctx.prefixes), + "__".join(self.ctx.prefixes) + (self.ctx.prop or ""), ) + if parents[0] == parents[-1]: + (_, parent_table) = root.get_output_table(None) + else: + (_, parent_table) = root.get_output_table( + "__" + "__".join(cast("Node", parents[0]).ctx.prefixes), + ) def create_array_table(source_stmt: sql.SQL) -> sql.Composed: return ( @@ -471,11 +485,11 @@ def create_array_table(source_stmt: sql.SQL) -> sql.Composed: [ sql.Identifier("a", "__id"), *[ - t.alias + sql.Identifier("p", t.alias) for p in reversed(parents) for t in p.direct(TypedNode) ], - *[t.stmt for t in self.direct(TypedNode)], + *[t.stmt for t in self.direct(FixedValueNode)], *[ t.stmt for o in self.descendents(ObjectNode) From 1f521f29d9da8afac31a863b660194ce03c4ec18 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Wed, 25 Mar 2026 15:45:06 +0000 Subject: [PATCH 20/28] Infer NodeContext from Node structure --- src/ldlite/database/_expansion/rewrite.py | 316 ++++++++++++---------- 1 file changed, 166 insertions(+), 150 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 5e6cf8c..86a51a5 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -19,136 +19,85 @@ JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean"] -class NodeContext: - def __init__( - self, - source: sql.Identifier, - column: sql.Identifier, - prefixes: list[str], - prop: str | None, - ): +class Node: + def __init__(self, source: sql.Identifier, prop: str | None): self.source = source - self.column = column - self.prefixes = prefixes - self.path = prefixes self.prop = prop + self.snake_prop: str | None = None - @staticmethod - def _snake(not_snake: str) -> str: - snake = "".join("_" + c.lower() if c.isupper() else c for c in not_snake) - # there's also sorts of weird edge cases here that don't come up in practice - if (naked := not_snake.lstrip("_")) and len(naked) > 0 and naked[0].isupper(): - snake = snake.removeprefix("_") - - return snake + if self.prop is not None: + self.snake_prop = "".join( + "_" + c.lower() if c.isupper() else c for c in self.prop + ) - @property - def snake_prop(self) -> str | None: - if self.prop is None: - return None - return self._snake(self.prop) + # there's also sorts of weird edge cases here that don't come up in practice + if ( + (naked := self.prop.lstrip("_")) + and len(naked) > 0 + and naked[0].isupper() + ): + self.snake_prop = self.snake_prop.removeprefix("_") - @property - def snake_prefixes(self) -> list[str]: - return [self._snake(p) for p in self.prefixes] - - def sub_prefix(self, prefix: str | None, prop: str | None) -> NodeContext: - return NodeContext( - self.source, - self.column, - [*self.prefixes, *([prefix] if prefix is not None else [])], - prop, - ) - def array_prefix( +class FixedValueNode(Node): + def __init__( self, source: sql.Identifier, - prefix: str | None, - ) -> NodeContext: - context = NodeContext( - source, - sql.Identifier("jsonb"), - [*self.prefixes, *([prefix] if prefix is not None else [])], - None, - ) - context.path = cast("list[str]", []) - return context - - -class Node: - def __init__(self, ctx: NodeContext): - self.ctx = ctx - - @property - def path(self) -> sql.Composed: - return sql.SQL("->").join([sql.Literal(p) for p in self.ctx.path]) - - @property - def _json_source(self) -> sql.Composable: - if len(self.ctx.path) == 0: - return self.ctx.column - - return self.ctx.column + sql.SQL("->") + self.path - - @property - def json_value(self) -> sql.Composable: - if self.ctx.prop is None: - return self._json_source - return self._json_source + sql.SQL("->") + sql.Literal(self.ctx.prop) - - @property - def json_string(self) -> sql.Composed: - if self.ctx.prop is None: - str_extract = ( - sql.SQL("""TRIM(BOTH '"' FROM """) - + self._json_source - + sql.SQL("::text)") - ) - else: - str_extract = ( - self._json_source + sql.SQL("->>") + sql.Literal(self.ctx.prop) - ) - - return sql.SQL("NULLIF(NULLIF(") + str_extract + sql.SQL(", ''), 'null')") + prop: str | None, + path: sql.Composable, + prefix: str, + ): + super().__init__(source, prop) + self.path = path + self.prefix = prefix -class FixedValueNode(Node): @property def alias(self) -> str: - return "" + if len(self.prefix) == 0: + return self.snake_prop or "" + return self.prefix + ( + ("__" + self.snake_prop) if self.snake_prop is not None else "" + ) @property - def stmt(self) -> sql.Composed: - return sql.Composed("") + def stmt(self) -> sql.Composable: + # this should be abstract but Python can't use ABCs as a generic + return sql.SQL("") class TypedNode(FixedValueNode): def __init__( self, - ctx: NodeContext, - json_type: JsonType, - other_json_type: JsonType, + source: sql.Identifier, + prop: str | None, + path: sql.Composable, + prefix: str, + json_types: tuple[JsonType, JsonType], ): - super().__init__(ctx) + super().__init__(source, prop, path, prefix) - self.is_mixed = json_type != other_json_type - self.json_type: JsonType = "string" if self.is_mixed else json_type + self.is_mixed = json_types[0] != json_types[1] + self.json_type: JsonType = "string" if self.is_mixed else json_types[0] self.is_uuid = False self.is_datetime = False self.is_float = False self.is_bigint = False @property - def alias(self) -> str: - if len(self.ctx.prefixes) == 0: - return self.ctx.snake_prop if self.ctx.snake_prop is not None else "" + def json_string(self) -> sql.Composable: + if self.prop is None: + str_extract = ( + sql.SQL("""TRIM(BOTH '"' FROM (""") + self.path + sql.SQL(")::text)") + ) + else: + str_extract = self.path + sql.SQL("->>") + sql.Literal(self.prop) - return "__".join(self.ctx.snake_prefixes) + ( - ("__" + self.ctx.snake_prop) if self.ctx.snake_prop is not None else "" - ) + return sql.SQL("NULLIF(NULLIF(") + str_extract + sql.SQL(", ''), 'null')") @property - def stmt(self) -> sql.Composed: + def stmt(self) -> sql.Composable: + type_extract: sql.Composable if self.json_type == "number" and self.is_float: type_extract = self.json_string + sql.SQL("::numeric") elif self.json_type == "number" and self.is_bigint: @@ -177,7 +126,7 @@ def specify_type(self, conn: Conn) -> None: + self.json_string + sql.SQL(""" AS string_value FROM {source} -)""").format(source=self.ctx.source) +)""").format(source=self.source) ) if self.json_type == "string": @@ -229,12 +178,20 @@ def specify_type(self, conn: Conn) -> None: class OrdinalNode(FixedValueNode): + def __init__( + self, + source: sql.Identifier, + path: sql.Composable, + prefix: str, + ): + super().__init__(source, None, path, prefix) + @property def alias(self) -> str: - return "__".join(self.ctx.snake_prefixes) + "__o" + return self.prefix + "__o" @property - def stmt(self) -> sql.Composed: + def stmt(self) -> sql.Composable: return sql.Identifier("a", "__o") + sql.SQL(" AS {alias}").format( alias=sql.Identifier(self.alias), ) @@ -245,12 +202,72 @@ def stmt(self) -> sql.Composed: class RecursiveNode(Node): - def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): - super().__init__(ctx) + def __init__( + self, + source: sql.Identifier, + prop: str | None, + column: sql.Identifier, + parent: RecursiveNode | None, + ): + super().__init__(source, prop) self.parent = parent + self.column = column self._children: list[Node] = [] + def _parents(self) -> Iterator[RecursiveNode]: + p = self.parent + while p is not None: + yield p + p = p.parent + + @property + def parents(self) -> list[RecursiveNode]: + return list(self._parents()) + + @property + def table_parent(self) -> RecursiveNode: + for p in self.parents: + if isinstance(p, (ArrayNode, RootNode)): + return p + + # There's "always" a root node + return None # type: ignore[return-value] + + @property + def path(self) -> sql.Composable: + prop_accessor: sql.Composable + if self.prop is None: + prop_accessor = sql.SQL("") + else: + prop_accessor = sql.SQL("->") + sql.Literal(self.prop) + + path: list[str] = [] + for p in self.parents: + if isinstance(p, (ArrayNode, RootNode)): + break + if p.prop is not None: + path.append(p.prop) + + if len(path) == 0: + return self.column + prop_accessor + + return ( + self.column + + sql.SQL("->") + + sql.SQL("->").join([sql.Literal(p) for p in reversed(path)]) + + prop_accessor + ) + + @property + def prefix(self) -> str: + if len(self.parents) == 0 or isinstance(self.parents[0], RootNode): + return self.snake_prop or "" + + return "__".join( + [p.snake_prop for p in reversed(self.parents) if p.snake_prop is not None], + ) + (("__" + self.snake_prop) if self.snake_prop is not None else "") + def _direct(self, cls: type[TNode]) -> Iterator[TNode]: yield from [n for n in self._children if isinstance(n, cls)] @@ -298,7 +315,7 @@ def load_columns(self, conn: Conn) -> None: FROM ( SELECT """) - + self.json_value + + self.path + sql.SQL(""" AS ld_value FROM {source_table} ) j CROSS JOIN LATERAL jsonb_each(j.ld_value) WITH ORDINALITY k("key", "value", ord) @@ -307,43 +324,46 @@ def load_columns(self, conn: Conn) -> None: WHERE json_type <> 'null' GROUP BY json_key ORDER BY MAX(ord), COUNT(*); -""").format(source_table=self.ctx.source) +""").format(source_table=self.source) ) cur.execute(key_discovery.as_string()) for row in cur.fetchall(): (key, jt, ojt) = cast("tuple[str, JsonType, JsonType]", row) if jt == "array" and ojt == "array": - anode = ArrayNode(self.ctx.sub_prefix(self.ctx.prop, key), self) + anode = ArrayNode(self.source, key, self.column, self) self._children.append(anode) elif jt == "object" and ojt == "object": - onode = ObjectNode(self.ctx.sub_prefix(self.ctx.prop, key), self) + onode = ObjectNode(self.source, key, self.column, self) self._children.append(onode) else: - tnode = TypedNode(self.ctx.sub_prefix(self.ctx.prop, key), jt, ojt) + tnode = TypedNode( + self.source, + key, + self.path, + self.prefix, + (jt, ojt), + ) self._children.append(tnode) -class StampableNode(ABC): +class StampableTable(ABC): @property @abstractmethod # The Callable construct is necessary until DuckDB implements CTAS RETURNING def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: ... -class RootNode(ObjectNode, StampableNode): +class RootNode(ObjectNode, StampableTable): def __init__( self, source: sql.Identifier, get_output_table: Callable[[str | None], tuple[str, sql.Identifier]], ): super().__init__( - NodeContext( - source, - sql.Identifier("jsonb"), - [], - None, - ), + source, + None, + sql.Identifier("jsonb"), None, ) self.get_output_table = get_output_table @@ -357,7 +377,7 @@ def create_root_table(source_stmt: sql.SQL) -> sql.Composed: sql.SQL(""" CREATE TABLE {output_table} AS WITH root_source AS (""").format(output_table=output_table) - + source_stmt.format(source_table=self.ctx.source) + + source_stmt.format(source_table=self.source) + sql.SQL( """) SELECT @@ -380,9 +400,15 @@ def create_root_table(source_stmt: sql.SQL) -> sql.Composed: return (output_table_name, create_root_table) -class ArrayNode(RecursiveNode, StampableNode): - def __init__(self, ctx: NodeContext, parent: RecursiveNode | None): - super().__init__(ctx, parent) +class ArrayNode(RecursiveNode, StampableTable): + def __init__( + self, + source: sql.Identifier, + prop: str | None, + column: sql.Identifier, + parent: RecursiveNode | None, + ): + super().__init__(source, prop, column, parent) self.temp = sql.Identifier(str(uuid4()).split("-")[0]) def make_temp(self, conn: Conn) -> Node | None: @@ -405,14 +431,14 @@ def make_temp(self, conn: Conn) -> Node | None: FROM ( SELECT """) - + self.json_value + + self.path + sql.SQL(""" AS ld_value, __id FROM {source} ) j CROSS JOIN LATERAL jsonb_array_elements(j.ld_value) WITH ORDINALITY a("value", ord) WHERE jsonb_typeof(j.ld_value) = 'array' ) expansion WHERE json_type <> 'null' -""").format(source=self.ctx.source) +""").format(source=self.source) ) cur.execute(expansion.as_string()) @@ -423,52 +449,42 @@ def make_temp(self, conn: Conn) -> Node | None: FROM {temp}""").format(temp=self.temp) cur.execute(type_discovery.as_string()) - self._children.append( - OrdinalNode(self.ctx.array_prefix(self.temp, self.ctx.prop)), - ) + self._children.append(OrdinalNode(self.temp, self.path, self.prefix)) if row := cur.fetchone(): (jt, ojt) = cast("tuple[JsonType, JsonType]", row) node: Node if jt == "array" and ojt == "array": - node = ArrayNode( - self.ctx.array_prefix(self.temp, self.ctx.prop), - self, - ) - self._children.append(node) + node = ArrayNode(self.temp, None, self.column, self) elif jt == "object" and ojt == "object": - node = ObjectNode( - self.ctx.array_prefix(self.temp, self.ctx.prop), - self, - ) - self._children.append(node) + node = ObjectNode(self.temp, None, self.column, self) else: node = TypedNode( - self.ctx.array_prefix(self.temp, self.ctx.prop), - jt, - ojt, + self.temp, + None, + sql.Identifier("jsonb"), + self.prefix, + (jt, ojt), ) - self._children.append(node) - + self._children.append(node) return node return None @property def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: - p: RecursiveNode | None = self + table_parent: RecursiveNode = self parents: list[RecursiveNode] = [] - while p is not None and (p := p.parent): + for p in self.parents: + if table_parent == self and isinstance(table_parent, (RootNode, ArrayNode)): + table_parent = p parents.append(p) + root = cast("RootNode", parents[-1]) - (output_table_name, output_table) = root.get_output_table( - "__".join(self.ctx.prefixes) + (self.ctx.prop or ""), - ) + (output_table_name, output_table) = root.get_output_table(self.prefix) if parents[0] == parents[-1]: (_, parent_table) = root.get_output_table(None) else: - (_, parent_table) = root.get_output_table( - "__" + "__".join(cast("Node", parents[0]).ctx.prefixes), - ) + (_, parent_table) = root.get_output_table(table_parent.prefix) def create_array_table(source_stmt: sql.SQL) -> sql.Composed: return ( From a7dc2f22f2bd565af5a29516074f6195af822681 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Wed, 25 Mar 2026 20:05:40 +0000 Subject: [PATCH 21/28] Implement json_depth --- src/ldlite/database/_expansion/rewrite.py | 141 ++++++++++++++++------ src/ldlite/database/_typed_database.py | 1 + 2 files changed, 107 insertions(+), 35 deletions(-) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/rewrite.py index 86a51a5..30912e0 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/rewrite.py @@ -16,7 +16,7 @@ from tqdm import tqdm Conn: TypeAlias = duckdb.DuckDBPyConnection | psycopg.Connection -JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean"] +JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean", "jsonb"] class Node: @@ -98,7 +98,9 @@ def json_string(self) -> sql.Composable: @property def stmt(self) -> sql.Composable: type_extract: sql.Composable - if self.json_type == "number" and self.is_float: + if self.json_type == "jsonb": + type_extract = self.path + elif self.json_type == "number" and self.is_float: type_extract = self.json_string + sql.SQL("::numeric") elif self.json_type == "number" and self.is_bigint: type_extract = self.json_string + sql.SQL("::bigint") @@ -116,7 +118,7 @@ def stmt(self) -> sql.Composable: return type_extract + sql.SQL(" AS ") + sql.Identifier(self.alias) def specify_type(self, conn: Conn) -> None: - if self.is_mixed or self.json_type == "boolean": + if self.is_mixed or self.json_type not in ["string", "number"]: return cte = ( @@ -152,10 +154,10 @@ def specify_type(self, conn: Conn) -> None: cur.execute(specify.as_string()) if row := cur.fetchone(): (self.is_uuid, self.is_datetime) = row - return - with conn.cursor() as cur: - specify = cte + sql.SQL(""" + if self.json_type == "number": + with conn.cursor() as cur: + specify = cte + sql.SQL(""" SELECT EXISTS( SELECT 1 FROM string_values @@ -170,11 +172,9 @@ def specify_type(self, conn: Conn) -> None: string_value::numeric > 2147483647 ) AS is_bigint;""") - cur.execute(specify.as_string()) - if row := cur.fetchone(): - (self.is_float, self.is_bigint) = row - else: - self.json_type = "string" + cur.execute(specify.as_string()) + if row := cur.fetchone(): + (self.is_float, self.is_bigint) = row class OrdinalNode(FixedValueNode): @@ -197,6 +197,16 @@ def stmt(self) -> sql.Composable: ) +class JsonbNode(TypedNode): + def __init__( + self, + source: sql.Identifier, + path: sql.Composable, + prefix: str, + ): + super().__init__(source, None, path, prefix, ("jsonb", "jsonb")) + + TNode = TypeVar("TNode", bound="Node") TRode = TypeVar("TRode", bound="RecursiveNode") @@ -234,6 +244,27 @@ def table_parent(self) -> RecursiveNode: # There's "always" a root node return None # type: ignore[return-value] + @property + def depth(self) -> int: + depth = 0 + prev = None + for p in self.parents: + # arrays of objects only count for a single level of depth + if not (isinstance(p, ObjectNode) and isinstance(prev, ArrayNode)): + depth += 1 + prev = p + + return depth + + def replace(self, original: Node, replacement: Node) -> None: + self._children = [(replacement if n == original else n) for n in self._children] + + def make_jsonb(self) -> None: + cast("RecursiveNode", self.parent).replace( + self, + JsonbNode(self.source, self.path, self.prefix), + ) + @property def path(self) -> sql.Composable: prop_accessor: sql.Composable @@ -274,19 +305,31 @@ def _direct(self, cls: type[TNode]) -> Iterator[TNode]: def direct(self, cls: type[TNode]) -> list[TNode]: return list(self._direct(cls)) - def _descendents(self, cls: type[TRode]) -> Iterator[TRode]: - to_visit = deque([self]) + def _descendents( + self, + cls: type[TRode], + to_cls: type[TRode] | None = None, + ) -> Iterator[TRode]: + to_visit = deque(self.direct(RecursiveNode)) while to_visit: n = to_visit.pop() if isinstance(n, cls): yield n + if to_cls is not None and isinstance(n, to_cls): + continue + to_visit.extend(n.direct(RecursiveNode)) - def descendents(self, cls: type[TRode]) -> list[TRode]: - return list(self._descendents(cls)) + def descendents( + self, + cls: type[TRode], + to_cls: type[TRode] | None = None, + ) -> list[TRode]: + return list(self._descendents(cls, to_cls)) def _typed_nodes(self) -> Iterator[TypedNode]: + yield from self.direct(TypedNode) for n in self._descendents(RecursiveNode): yield from n.direct(TypedNode) @@ -386,9 +429,10 @@ def create_root_table(source_stmt: sql.SQL) -> sql.Composed: + sql.SQL("\n ,").join( [ sql.Identifier("__id"), + *[t.stmt for t in self.direct(TypedNode)], *[ t.stmt - for o in self.descendents(ObjectNode) + for o in self.descendents(ObjectNode, ArrayNode) for t in o.direct(TypedNode) ], ], @@ -420,12 +464,12 @@ def make_temp(self, conn: Conn) -> Node | None: __id AS p__id ,(ROW_NUMBER() OVER (ORDER BY (SELECT NULL)))::integer AS __id ,ord::smallint AS __o - ,jsonb + ,array_jsonb ,json_type FROM ( SELECT j.__id - ,a."value" AS jsonb + ,a."value" AS array_jsonb ,jsonb_typeof(a."value") AS json_type ,a.ord FROM @@ -454,14 +498,24 @@ def make_temp(self, conn: Conn) -> Node | None: (jt, ojt) = cast("tuple[JsonType, JsonType]", row) node: Node if jt == "array" and ojt == "array": - node = ArrayNode(self.temp, None, self.column, self) + node = ArrayNode( + self.temp, + None, + sql.Identifier("array_jsonb"), + self, + ) elif jt == "object" and ojt == "object": - node = ObjectNode(self.temp, None, self.column, self) + node = ObjectNode( + self.temp, + None, + sql.Identifier("array_jsonb"), + self, + ) else: node = TypedNode( self.temp, None, - sql.Identifier("jsonb"), + sql.Identifier("array_jsonb"), self.prefix, (jt, ojt), ) @@ -475,13 +529,13 @@ def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: table_parent: RecursiveNode = self parents: list[RecursiveNode] = [] for p in self.parents: - if table_parent == self and isinstance(table_parent, (RootNode, ArrayNode)): + if table_parent == self and isinstance(p, (RootNode, ArrayNode)): table_parent = p parents.append(p) root = cast("RootNode", parents[-1]) (output_table_name, output_table) = root.get_output_table(self.prefix) - if parents[0] == parents[-1]: + if not isinstance(table_parent, ArrayNode): (_, parent_table) = root.get_output_table(None) else: (_, parent_table) = root.get_output_table(table_parent.prefix) @@ -503,12 +557,12 @@ def create_array_table(source_stmt: sql.SQL) -> sql.Composed: *[ sql.Identifier("p", t.alias) for p in reversed(parents) - for t in p.direct(TypedNode) + for t in p.direct(FixedValueNode) ], *[t.stmt for t in self.direct(FixedValueNode)], *[ t.stmt - for o in self.descendents(ObjectNode) + for o in self.descendents(ObjectNode, ArrayNode) for t in o.direct(TypedNode) ], ], @@ -527,6 +581,7 @@ def _non_srs_statements( conn: Conn, source_table: sql.Identifier, output_table: Callable[[str | None], tuple[str, sql.Identifier]], + json_depth: int, scan_progress: tqdm[NoReturn], ) -> Iterator[tuple[str, Callable[[sql.SQL], sql.Composed]]]: # Here be dragons! The nodes have inner state manipulations @@ -539,27 +594,34 @@ def _non_srs_statements( # Because building up to the transformation statements takes a long time # we're doing all that work up front to keep the time that # a transaction is opened to a minimum (which is a leaky abstraction). + scan_progress.total = scan_progress.total if scan_progress.total is not None else 1 root = RootNode(source_table, output_table) onodes: deque[ObjectNode] = deque([root]) while onodes: o = onodes.popleft() - o.load_columns(conn) - scan_progress.total += len(o.direct(Node)) + if o.depth < json_depth: + o.load_columns(conn) + scan_progress.total += len(o.direct(Node)) + else: + o.make_jsonb() scan_progress.update(1) onodes.extend(o.direct(ObjectNode)) anodes = deque(o.direct(ArrayNode)) while anodes: a = anodes.popleft() - if n := a.make_temp(conn): - if isinstance(n, ObjectNode): - onodes.append(n) - if isinstance(n, ArrayNode): - anodes.append(n) - scan_progress.total += 1 + if a.depth < json_depth: + if n := a.make_temp(conn): + if isinstance(n, ObjectNode): + onodes.append(n) + if isinstance(n, ArrayNode): + anodes.append(n) + scan_progress.total += 1 + else: + cast("RecursiveNode", a.parent).remove(a) else: - cast("RecursiveNode", a.parent).remove(a) + a.make_jsonb() scan_progress.update(1) @@ -576,6 +638,15 @@ def non_srs_statements( conn: Conn, source_table: sql.Identifier, output_table: Callable[[str | None], tuple[str, sql.Identifier]], + json_depth: int, scan_progress: tqdm[NoReturn], ) -> list[tuple[str, Callable[[sql.SQL], sql.Composed]]]: - return list(_non_srs_statements(conn, source_table, output_table, scan_progress)) + return list( + _non_srs_statements( + conn, + source_table, + output_table, + json_depth, + scan_progress, + ), + ) diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 9f412e1..3b106f6 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -225,6 +225,7 @@ def expand_prefix( conn, pfx.raw_table[1], pfx.output_table, + json_depth, scan_progress if scan_progress is not None else tqdm(disable=True, total=0), From 25523640b3bb103ccedb58a95653ecb0f35c553f Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Wed, 25 Mar 2026 20:18:06 +0000 Subject: [PATCH 22/28] Replace legacy version with the rewrite --- src/ldlite/database/_expansion/__init__.py | 245 ++++------ src/ldlite/database/_expansion/context.py | 52 --- src/ldlite/database/_expansion/fixed_nodes.py | 181 ++++++++ src/ldlite/database/_expansion/metadata.py | 141 ------ src/ldlite/database/_expansion/node.py | 33 ++ src/ldlite/database/_expansion/nodes.py | 435 ------------------ .../{rewrite.py => recursive_nodes.py} | 274 +---------- src/ldlite/database/_typed_database.py | 2 +- 8 files changed, 302 insertions(+), 1061 deletions(-) delete mode 100644 src/ldlite/database/_expansion/context.py create mode 100644 src/ldlite/database/_expansion/fixed_nodes.py delete mode 100644 src/ldlite/database/_expansion/metadata.py create mode 100644 src/ldlite/database/_expansion/node.py delete mode 100644 src/ldlite/database/_expansion/nodes.py rename src/ldlite/database/_expansion/{rewrite.py => recursive_nodes.py} (58%) diff --git a/src/ldlite/database/_expansion/__init__.py b/src/ldlite/database/_expansion/__init__.py index 0d1fe41..e11620b 100644 --- a/src/ldlite/database/_expansion/__init__.py +++ b/src/ldlite/database/_expansion/__init__.py @@ -1,167 +1,90 @@ from __future__ import annotations from collections import deque -from typing import TYPE_CHECKING - -from psycopg import sql - -from .nodes import ArrayNode, ObjectNode +from typing import TYPE_CHECKING, cast if TYPE_CHECKING: - from .context import ExpandContext - - -def expand_nonmarc( - root_name: str, - root_values: list[str], - ctx: ExpandContext, -) -> list[tuple[str, sql.Composed]]: - (_, tables_to_create) = _expand_nonmarc( - ObjectNode(root_name, "", None, root_values), - 0, - ctx, - ) - return tables_to_create - - -def _expand_nonmarc( # noqa: PLR0915 - root: ObjectNode, - count: int, - ctx: ExpandContext, -) -> tuple[int, list[tuple[str, sql.Composed]]]: - ctx.scan_progress.total = (ctx.scan_progress.total or 0) + 1 - ctx.scan_progress.refresh() - ctx.transform_progress.total = (ctx.transform_progress.total or 0) + 1 - ctx.transform_progress.refresh() - initial_count = count - ctx.preprocess(ctx.conn, ctx.source_table, [root.identifier]) - has_rows = root.unnest( - ctx, - ctx.source_table, - ctx.get_transform_table(count), - ctx.source_cte(False), - ) - ctx.transform_progress.update(1) - if not has_rows: - return (0, []) - - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL("DROP TABLE {previous_table}") - .format(previous_table=ctx.source_table) - .as_string(), - ) - - expand_children_of = deque([root]) - while expand_children_of: - on = expand_children_of.popleft() - if ctx.transform_progress: - ctx.transform_progress.total += len(on.object_children) - ctx.transform_progress.refresh() - for c in on.object_children: - if len(c.parents) >= ctx.json_depth: - if c.parent is not None: - c.parent.values.append(c.name) - continue - ctx.preprocess(ctx.conn, ctx.get_transform_table(count), [c.identifier]) - c.unnest( - ctx, - ctx.get_transform_table(count), - ctx.get_transform_table(count + 1), - ctx.source_cte(False), - ) - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL("DROP TABLE {previous_table}") - .format(previous_table=ctx.get_transform_table(count)) - .as_string(), - ) - expand_children_of.append(c) - count += 1 - ctx.transform_progress.update(1) - - tables_to_create = [] - - new_source_table = ctx.get_transform_table(count) - arrays = root.descendents_oftype(ArrayNode) - ctx.transform_progress.total += len(arrays) - ctx.transform_progress.refresh() - ctx.preprocess(ctx.conn, new_source_table, [a.identifier for a in arrays]) - for an in arrays: - if len(an.parents) >= ctx.json_depth: - continue - values = an.explode( - ctx.conn, - new_source_table, - ctx.get_transform_table(count + 1), - ctx.source_cte(True), - ) - count += 1 - ctx.transform_progress.update(1) - - if an.is_object: - (sub_index, array_tables) = _expand_nonmarc( - ObjectNode( - an.name, - an.name, - None, - values, - ), - count + 1, - ctx.array_context( - ctx.get_transform_table(count), - ctx.json_depth - len(an.parents), - ), - ) - count += sub_index - tables_to_create.extend(array_tables) + from collections.abc import Callable, Iterator + from typing import NoReturn + + from psycopg import sql + from tqdm import tqdm + + +from .node import Conn, Node +from .recursive_nodes import ArrayNode, ObjectNode, RecursiveNode, RootNode + + +def _non_srs_statements( + conn: Conn, + source_table: sql.Identifier, + output_table: Callable[[str | None], tuple[str, sql.Identifier]], + json_depth: int, + scan_progress: tqdm[NoReturn], +) -> Iterator[tuple[str, Callable[[sql.SQL], sql.Composed]]]: + # Here be dragons! The nodes have inner state manipulations + # that violate the space/time continuum: + # * o.load_columns + # * a.make_temp + # * t.specify_type + # These all are expected to be called before generating the sql + # as they load/prepare database information. + # Because building up to the transformation statements takes a long time + # we're doing all that work up front to keep the time that + # a transaction is opened to a minimum (which is a leaky abstraction). + scan_progress.total = scan_progress.total if scan_progress.total is not None else 1 + + root = RootNode(source_table, output_table) + onodes: deque[ObjectNode] = deque([root]) + while onodes: + o = onodes.popleft() + if o.depth < json_depth: + o.load_columns(conn) + scan_progress.total += len(o.direct(Node)) else: - with ctx.conn.cursor() as cur: - (tname, tid) = ctx.get_output_table(an.name) - tables_to_create.append( - ( - tname, - sql.SQL( - """ -CREATE TABLE {dest_table} AS -""" - + ctx.source_cte(False) - + """ -SELECT {cols} FROM ld_source -""", - ).format( - dest_table=tid, - source_table=ctx.get_transform_table(count), - cols=sql.SQL("\n ,").join( - [sql.Identifier(v) for v in [*values, an.name]], - ), - ), - ), - ) - - stamped_values = [ - sql.Identifier(v) for n in root.descendents if n not in arrays for v in n.values - ] - - with ctx.conn.cursor() as cur: - (tname, tid) = ctx.get_output_table(root.path) - tables_to_create.append( - ( - tname, - sql.SQL( - """ -CREATE TABLE {dest_table} AS -""" - + ctx.source_cte(False) - + """ -SELECT {cols} FROM ld_source -""", - ).format( - dest_table=tid, - source_table=new_source_table, - cols=sql.SQL("\n ,").join(stamped_values), - ), - ), - ) - - return (count + 1 - initial_count, tables_to_create) + o.make_jsonb() + scan_progress.update(1) + + onodes.extend(o.direct(ObjectNode)) + anodes = deque(o.direct(ArrayNode)) + while anodes: + a = anodes.popleft() + if a.depth < json_depth: + if n := a.make_temp(conn): + if isinstance(n, ObjectNode): + onodes.append(n) + if isinstance(n, ArrayNode): + anodes.append(n) + scan_progress.total += 1 + else: + cast("RecursiveNode", a.parent).remove(a) + else: + a.make_jsonb() + + scan_progress.update(1) + + for t in root.typed_nodes(): + t.specify_type(conn) + scan_progress.update(1) + + yield root.create_statement + for a in root.descendents(ArrayNode): + yield a.create_statement + + +def non_srs_statements( + conn: Conn, + source_table: sql.Identifier, + output_table: Callable[[str | None], tuple[str, sql.Identifier]], + json_depth: int, + scan_progress: tqdm[NoReturn], +) -> list[tuple[str, Callable[[sql.SQL], sql.Composed]]]: + return list( + _non_srs_statements( + conn, + source_table, + output_table, + json_depth, + scan_progress, + ), + ) diff --git a/src/ldlite/database/_expansion/context.py b/src/ldlite/database/_expansion/context.py deleted file mode 100644 index 7b9e29a..0000000 --- a/src/ldlite/database/_expansion/context.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING, NoReturn - -if TYPE_CHECKING: - from collections.abc import Callable - - import duckdb - import psycopg - from psycopg import sql - from tqdm import tqdm - - -@dataclass -class ExpandContext: - conn: duckdb.DuckDBPyConnection | psycopg.Connection - source_table: sql.Identifier - json_depth: int - get_transform_table: Callable[[int], sql.Identifier] - get_output_table: Callable[[str], tuple[str, sql.Identifier]] - # This is necessary for Analyzing the table in pg before querying it - # I don't love how this is implemented - preprocess: Callable[ - [ - duckdb.DuckDBPyConnection | psycopg.Connection, - sql.Identifier, - list[sql.Identifier], - ], - None, - ] - # source_cte will go away when DuckDB implements CTAS RETURNING - source_cte: Callable[[bool], str] - scan_progress: tqdm[NoReturn] - transform_progress: tqdm[NoReturn] - - def array_context( - self, - new_source_table: sql.Identifier, - new_json_depth: int, - ) -> ExpandContext: - return ExpandContext( - self.conn, - new_source_table, - new_json_depth, - self.get_transform_table, - self.get_output_table, - self.preprocess, - self.source_cte, - self.scan_progress, - self.transform_progress, - ) diff --git a/src/ldlite/database/_expansion/fixed_nodes.py b/src/ldlite/database/_expansion/fixed_nodes.py new file mode 100644 index 0000000..3dbb505 --- /dev/null +++ b/src/ldlite/database/_expansion/fixed_nodes.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from psycopg import sql + +if TYPE_CHECKING: + from typing import TypeAlias + + +from .node import Conn, Node + +JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean", "jsonb"] + + +class FixedValueNode(Node): + def __init__( + self, + source: sql.Identifier, + prop: str | None, + path: sql.Composable, + prefix: str, + ): + super().__init__(source, prop) + + self.path = path + self.prefix = prefix + + @property + def alias(self) -> str: + if len(self.prefix) == 0: + return self.snake_prop or "" + return self.prefix + ( + ("__" + self.snake_prop) if self.snake_prop is not None else "" + ) + + @property + def stmt(self) -> sql.Composable: + # this should be abstract but Python can't use ABCs as a generic + return sql.SQL("") + + +class TypedNode(FixedValueNode): + def __init__( + self, + source: sql.Identifier, + prop: str | None, + path: sql.Composable, + prefix: str, + json_types: tuple[JsonType, JsonType], + ): + super().__init__(source, prop, path, prefix) + + self.is_mixed = json_types[0] != json_types[1] + self.json_type: JsonType = "string" if self.is_mixed else json_types[0] + self.is_uuid = False + self.is_datetime = False + self.is_float = False + self.is_bigint = False + + @property + def json_string(self) -> sql.Composable: + if self.prop is None: + str_extract = ( + sql.SQL("""TRIM(BOTH '"' FROM (""") + self.path + sql.SQL(")::text)") + ) + else: + str_extract = self.path + sql.SQL("->>") + sql.Literal(self.prop) + + return sql.SQL("NULLIF(NULLIF(") + str_extract + sql.SQL(", ''), 'null')") + + @property + def stmt(self) -> sql.Composable: + type_extract: sql.Composable + if self.json_type == "jsonb": + type_extract = self.path + elif self.json_type == "number" and self.is_float: + type_extract = self.json_string + sql.SQL("::numeric") + elif self.json_type == "number" and self.is_bigint: + type_extract = self.json_string + sql.SQL("::bigint") + elif self.json_type == "number": + type_extract = self.json_string + sql.SQL("::integer") + elif self.json_type == "boolean": + type_extract = self.json_string + sql.SQL("::bool") + elif self.json_type == "string" and self.is_uuid: + type_extract = self.json_string + sql.SQL("::uuid") + elif self.json_type == "string" and self.is_datetime: + type_extract = self.json_string + sql.SQL("::timestamptz") + else: + type_extract = self.json_string + + return type_extract + sql.SQL(" AS ") + sql.Identifier(self.alias) + + def specify_type(self, conn: Conn) -> None: + if self.is_mixed or self.json_type not in ["string", "number"]: + return + + cte = ( + sql.SQL(""" +WITH string_values AS MATERIALIZED ( + SELECT """) + + self.json_string + + sql.SQL(""" AS string_value + FROM {source} +)""").format(source=self.source) + ) + + if self.json_type == "string": + with conn.cursor() as cur: + specify = cte + sql.SQL(""" +SELECT + NOT EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + string_value NOT LIKE '________-____-____-____-____________' + ) AS is_uuid + ,NOT EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + ( + string_value NOT LIKE '____-__-__T__:__:__.___' AND + string_value NOT LIKE '____-__-__T__:__:__.___+__:__' + ) + ) AS is_uuid;""") + + cur.execute(specify.as_string()) + if row := cur.fetchone(): + (self.is_uuid, self.is_datetime) = row + + if self.json_type == "number": + with conn.cursor() as cur: + specify = cte + sql.SQL(""" +SELECT + EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + string_value::numeric % 1 <> 0 + ) AS is_float + ,EXISTS( + SELECT 1 FROM string_values + WHERE + string_value IS NOT NULL AND + string_value::numeric > 2147483647 + ) AS is_bigint;""") + + cur.execute(specify.as_string()) + if row := cur.fetchone(): + (self.is_float, self.is_bigint) = row + + +class OrdinalNode(FixedValueNode): + def __init__( + self, + source: sql.Identifier, + path: sql.Composable, + prefix: str, + ): + super().__init__(source, None, path, prefix) + + @property + def alias(self) -> str: + return self.prefix + "__o" + + @property + def stmt(self) -> sql.Composable: + return sql.Identifier("a", "__o") + sql.SQL(" AS {alias}").format( + alias=sql.Identifier(self.alias), + ) + + +class JsonbNode(TypedNode): + def __init__( + self, + source: sql.Identifier, + path: sql.Composable, + prefix: str, + ): + super().__init__(source, None, path, prefix, ("jsonb", "jsonb")) diff --git a/src/ldlite/database/_expansion/metadata.py b/src/ldlite/database/_expansion/metadata.py deleted file mode 100644 index 7fbc2ee..0000000 --- a/src/ldlite/database/_expansion/metadata.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Literal - -from psycopg import sql - - -class Metadata(ABC): - def __init__(self, prop: str | None): - self.prop = prop - - @property - def snake(self) -> str: - if self.prop is None: - # this doesn't really come up in practice - return "$" - - snake = "".join("_" + c.lower() if c.isupper() else c for c in self.prop) - - # there's also sorts of weird edge cases here that don't come up in practice - if (naked := self.prop.lstrip("_")) and len(naked) > 0 and naked[0].isupper(): - snake = snake.removeprefix("_") - - return snake - - @property - @abstractmethod - def select_stmt(self) -> str: ... - - def select_column( - self, - json_col: sql.Identifier, - alias: str, - ) -> sql.Composed: - return sql.SQL(self.select_stmt + " AS {alias}").format( - json_col=json_col, - prop=self.prop, - alias=sql.Identifier(alias), - ) - - -class ObjectMeta(Metadata): - @property - def select_stmt(self) -> str: - return "{json_col}" if self.prop is None else "{json_col}->{prop}" - - -class ArrayMeta(Metadata): - @property - def select_stmt(self) -> str: - return "{json_col}" if self.prop is None else "{json_col}->{prop}" - - @abstractmethod - def unwrap(self) -> ObjectMeta | TypedMeta: ... - - -class TypedMeta(Metadata): - def __init__( # noqa: PLR0913 - self, - prop: str | None, - json_type: Literal["string", "number", "boolean"], - other_json_type: Literal["string", "number", "boolean"], - is_uuid: bool, - is_datetime: bool, - is_float: bool, - is_bigint: bool, - ): - super().__init__(prop) - - mixed_type = json_type != other_json_type - self.json_type: Literal["string", "number", "boolean"] = ( - json_type if not mixed_type else "string" - ) - self.is_uuid = is_uuid and not mixed_type - self.is_datetime = is_datetime and not mixed_type - self.is_float = is_float and not mixed_type - self.is_bigint = is_bigint and not mixed_type - - @property - def select_stmt(self) -> str: # noqa: PLR0911 - str_extract = ( - "{json_col}->>{prop}" - if self.prop is not None - else """TRIM(BOTH '"' FROM ({json_col})::text)""" - ) - str_extract = f"NULLIF(NULLIF({str_extract}, ''), 'null')" - - if self.json_type == "number" and self.is_float: - return f"{str_extract}::numeric" - if self.json_type == "number" and self.is_bigint: - return f"{str_extract}::bigint" - if self.json_type == "number": - return f"{str_extract}::integer" - if self.json_type == "boolean": - return f"{str_extract}::bool" - if self.json_type == "string" and self.is_uuid: - return f"{str_extract}::uuid" - if self.json_type == "string" and self.is_datetime: - return f"{str_extract}::timestamptz" - - return str_extract - - -class MixedMeta(TypedMeta): - def __init__( - self, - prop: str | None, - ): - super().__init__(prop, "string", "string", False, False, False, False) - - -class ObjectArrayMeta(ObjectMeta, ArrayMeta): - def unwrap(self) -> ObjectMeta: - return ObjectMeta(None) - - -class MixedArrayMeta(MixedMeta, ArrayMeta): - @property - def select_stmt(self) -> str: - return "{json_col}" if self.prop is None else "{json_col}->{prop}" - - def unwrap(self) -> MixedMeta: - return MixedMeta(None) - - -class TypedArrayMeta(TypedMeta, ArrayMeta): - @property - def select_stmt(self) -> str: - return "{json_col}" if self.prop is None else "{json_col}->{prop}" - - def unwrap(self) -> TypedMeta: - return TypedMeta( - None, - self.json_type, - self.json_type, - self.is_uuid, - self.is_datetime, - self.is_float, - self.is_bigint, - ) diff --git a/src/ldlite/database/_expansion/node.py b/src/ldlite/database/_expansion/node.py new file mode 100644 index 0000000..eb7d6be --- /dev/null +++ b/src/ldlite/database/_expansion/node.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import duckdb +import psycopg + +if TYPE_CHECKING: + from typing import TypeAlias + + from psycopg import sql + +Conn: TypeAlias = duckdb.DuckDBPyConnection | psycopg.Connection + + +class Node: + def __init__(self, source: sql.Identifier, prop: str | None): + self.source = source + self.prop = prop + self.snake_prop: str | None = None + + if self.prop is not None: + self.snake_prop = "".join( + "_" + c.lower() if c.isupper() else c for c in self.prop + ) + + # there's also sorts of weird edge cases here that don't come up in practice + if ( + (naked := self.prop.lstrip("_")) + and len(naked) > 0 + and naked[0].isupper() + ): + self.snake_prop = self.snake_prop.removeprefix("_") diff --git a/src/ldlite/database/_expansion/nodes.py b/src/ldlite/database/_expansion/nodes.py deleted file mode 100644 index 32626b5..0000000 --- a/src/ldlite/database/_expansion/nodes.py +++ /dev/null @@ -1,435 +0,0 @@ -from __future__ import annotations - -from collections import deque -from typing import TYPE_CHECKING, TypeVar, cast - -from psycopg import sql - -if TYPE_CHECKING: - from collections.abc import Iterator - - import duckdb - import psycopg - - from .context import ExpandContext - -from .metadata import ( - ArrayMeta, - Metadata, - MixedArrayMeta, - MixedMeta, - ObjectArrayMeta, - ObjectMeta, - TypedArrayMeta, - TypedMeta, -) - -TNode = TypeVar("TNode", bound="ExpansionNode") - - -class ExpansionNode: - def __init__( - self, - name: str, - path: str, - parent: ExpansionNode | None, - values: list[str] | None = None, - ): - self.name = name - self.path = path - self.identifier = sql.Identifier(name) - self.parent = parent - self.values: list[str] = values or [] - self.children: list[ExpansionNode] = [] - - def add(self, meta: Metadata) -> str: - snake = meta.snake - prefixed_name = self.prefix + snake - - if isinstance(meta, ArrayMeta): - self.children.append(ArrayNode(prefixed_name, snake, self, meta)) - elif isinstance(meta, ObjectMeta): - self.children.append(ObjectNode(prefixed_name, snake, self)) - else: - prefixed_name = self.prefix + snake - self.values.append(prefixed_name) - - return prefixed_name - - def _parents(self) -> Iterator[ExpansionNode]: - n = self - while n.parent is not None: - yield n.parent - n = n.parent - - @property - def parents(self) -> list[ExpansionNode]: - return list(self._parents()) - - @property - def prefix(self) -> str: - if len(self.parents) == 0 and len(self.path) == 0: - return "" - - return ( - "__".join( - [*[p.path for p in self.parents if len(p.path) != 0], self.path], - ) - + "__" - ) - - @property - def root(self) -> ExpansionNode: - if self.parent is None: - return self - - root = [p for p in self.parents if p.parent is None] - return root[0] - - def _descendents(self, cls: type[TNode]) -> Iterator[TNode]: - to_visit = deque([self]) - while to_visit: - n = to_visit.pop() - if isinstance(n, cls): - yield n - - to_visit.extend(n.children) - - @property - def descendents(self) -> list[ExpansionNode]: - return list(self._descendents(ExpansionNode)) - - def descendents_oftype(self, cls: type[TNode]) -> list[TNode]: - return list(self._descendents(cls)) - - def __str__(self) -> str: - return "->".join([n.name for n in reversed([self, *self.parents])]) - - -class ObjectNode(ExpansionNode): - def __init__( - self, - name: str, - path: str, - parent: ExpansionNode | None, - values: list[str] | None = None, - ): - super().__init__(name, path, parent, values) - self.unnested = False - - def _object_children(self) -> Iterator[ObjectNode]: - for c in self.children: - if isinstance(c, ObjectNode): - yield c - - @property - def object_children(self) -> list[ObjectNode]: - return list(self._object_children()) - - def unnest( - self, - ctx: ExpandContext, - source_table: sql.Identifier, - dest_table: sql.Identifier, - source_cte: str, - ) -> bool: - self.unnested = True - create_columns: list[sql.Composable] = [ - sql.Identifier(v) for v in self.carryover - ] - - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL("SELECT 1 FROM {table} LIMIT 1;") - .format(table=source_table) - .as_string(), - ) - if not cur.fetchone(): - return False - - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL( - """ -SELECT k.ld_key -FROM - {source_table} t - ,jsonb_object_keys(t.{json_col}) WITH ORDINALITY k(ld_key, "ordinality") -WHERE t.{json_col} IS NOT NULL AND jsonb_typeof(t.{json_col}) = 'object' -GROUP BY k.ld_key -ORDER BY MAX(k.ordinality), COUNT(k.ordinality) -""", - ) - .format(source_table=source_table, json_col=self.identifier) - .as_string(), - ) - props = [prop[0] for prop in cur.fetchall()] - - ctx.scan_progress.total += len(props) * 3 - ctx.scan_progress.refresh() - ctx.scan_progress.update(1) - - metadata: list[Metadata] = [] - for prop in props: - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL( - """ -SELECT - BOOL_AND(json_type = 'array') AS only_array - ,BOOL_OR(json_type = 'array') AS some_array - ,BOOL_AND(json_type = 'object') AS only_object - ,BOOL_OR(json_type = 'object') AS some_object -FROM -( - SELECT jsonb_typeof(t.{json_col}->$1) AS json_type - FROM {table} t -) j -WHERE json_type <> 'null' -""", - ) - .format( - table=source_table, - json_col=self.identifier, - ) - .as_string(), - (prop,), - ) - (only_array, some_array, only_object, some_object) = cast( - "tuple[bool, bool, bool, bool]", - cur.fetchone(), - ) - - if (some_array and not only_array) or (some_object and not only_object): - metadata.append(MixedMeta(prop)) - ctx.scan_progress.update(3) - continue - - if only_object: - metadata.append(ObjectMeta(prop)) - ctx.scan_progress.total += 1 - ctx.scan_progress.update(3) - continue - - if only_array: - ctx.scan_progress.update(1) - cur.execute( - sql.SQL( - """ -SELECT - -- Technically arrays could be nested but I haven't seen any - BOOL_AND(json_type = 'object') AS only_object - ,BOOL_OR(json_type = 'object') AS some_object -FROM -( - SELECT a.json_type - FROM {table} t - CROSS JOIN LATERAL - ( - SELECT jsonb_typeof(ld_value) AS json_type - FROM jsonb_array_elements(t.{json_col}->$1) a(ld_value) - WHERE jsonb_typeof(t.{json_col}->$1) = 'array' - ) a -) j -WHERE json_type <> 'null' -""", - ) - .format( - table=source_table, - json_col=self.identifier, - ) - .as_string(), - (prop,), - ) - (inner_only_object, inner_some_object) = cast( - "tuple[bool, bool]", - cur.fetchone(), - ) - - if inner_some_object and not inner_only_object: - metadata.append(MixedArrayMeta(prop)) - ctx.scan_progress.update(2) - continue - - if inner_only_object: - metadata.append(ObjectArrayMeta(prop)) - ctx.scan_progress.total += 1 - ctx.scan_progress.update(2) - continue - - ctx.scan_progress.update(1) - typed_from_sql = """ -FROM {table} t -CROSS JOIN LATERAL -( - SELECT * - FROM jsonb_array_elements(t.{json_col}->$1) a(ld_value) - WHERE jsonb_typeof(t.{json_col}->$1) = 'array' - LIMIT 3 -) j""" - else: - ctx.scan_progress.update(2) - typed_from_sql = """ -FROM (SELECT t.{json_col}->$1 AS ld_value FROM {table} t) j -""" - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL( - """ -SELECT - MIN(json_type) AS json_type - ,MAX(json_type) AS other_json_type - ,BOOL_AND(CASE WHEN json_type = 'string' THEN (ld_value)::text LIKE '"________-____-____-____-____________"' ELSE FALSE END) AS is_uuid - ,BOOL_AND(CASE WHEN json_type = 'string' THEN (ld_value)::text LIKE '"____-__-__T__:__:__.___%"' ELSE FALSE END) AS is_datetime - ,BOOL_OR(CASE WHEN json_type = 'number' THEN (ld_value)::numeric % 1 <> 0 ELSE FALSE END) AS is_float - ,BOOL_OR(CASE WHEN json_type = 'number' THEN (ld_value)::numeric > 2147483647 ELSE FALSE END) AS is_bigint -FROM -( - SELECT - ld_value - ,jsonb_typeof(ld_value) json_type """ # noqa: E501 - + typed_from_sql - + """ - WHERE ld_value IS NOT NULL -) i -WHERE - ld_value IS NOT NULL AND - json_type <> 'null' AND - ( - json_type <> 'string' OR - (json_type = 'string' AND ld_value::text NOT IN ('"null"', '""')) - ) -""", - ) - .format( - table=source_table, - json_col=self.identifier, - ) - .as_string(), - (prop,), - ) - if (row := cur.fetchone()) is not None and all( - c is not None for c in row - ): - metadata.append( - TypedArrayMeta(prop, *row) - if only_array - else TypedMeta(prop, *row), - ) - ctx.scan_progress.update(1) - - create_columns.extend( - [meta.select_column(self.identifier, self.add(meta)) for meta in metadata], - ) - - with ctx.conn.cursor() as cur: - cur.execute( - sql.SQL( - """ -CREATE TEMP TABLE {dest_table} AS -""" - + source_cte - + """ -SELECT - {cols} -FROM ld_source -""", - ) - .format( - source_table=source_table, - dest_table=dest_table, - json_col=self.identifier, - cols=sql.SQL("\n ,").join(create_columns), - ) - .as_string(), - ) - - return True - - def _carryover(self) -> Iterator[str]: - for n in self.root.descendents: - if isinstance(n, ObjectNode) and not n.unnested and n.name != "jsonb": - yield n.name - if isinstance(n, ArrayNode): - yield n.name - yield from n.values - - @property - def carryover(self) -> list[str]: - return list(self._carryover()) - - -class ArrayNode(ExpansionNode): - def __init__( - self, - name: str, - path: str, - parent: ExpansionNode | None, - meta: ArrayMeta, - values: list[str] | None = None, - ): - super().__init__(name, path, parent, values) - self.meta = meta.unwrap() - - @property - def is_object(self) -> bool: - return isinstance(self.meta, ObjectMeta) - - def explode( - self, - conn: duckdb.DuckDBPyConnection | psycopg.Connection, - source_table: sql.Identifier, - dest_table: sql.Identifier, - source_cte: str, - ) -> list[str]: - with conn.cursor() as cur: - o_col = self.name + "__o" - create_columns: list[sql.Composable] = [ - sql.SQL( - "(ROW_NUMBER() OVER (ORDER BY (SELECT NULL)))::integer AS __id", - ), - *[sql.Identifier(v) for v in self.carryover], - sql.SQL( - """a."ordinality"::smallint AS {id_alias}""", - ).format( - id_alias=sql.Identifier(o_col), - ), - self.meta.select_column( - sql.Identifier("a", "ld_value"), - self.name, - ), - ] - - cur.execute( - sql.SQL( - """ -CREATE TEMP TABLE {dest_table} AS -""" - + source_cte - + """ -SELECT - {cols} -FROM - ld_source s - ,jsonb_array_elements(s.{json_col}) WITH ORDINALITY a(ld_value, "ordinality") -WHERE jsonb_typeof(s.{json_col}) = 'array' -""", - ) - .format( - source_table=source_table, - dest_table=dest_table, - cols=sql.SQL("\n ,").join(create_columns), - json_col=sql.Identifier(self.name), - ) - .as_string(), - ) - - return ["__id", *self.carryover, o_col] - - def _carryover(self) -> Iterator[str]: - for n in reversed(self.parents): - yield from [v for v in n.values if v not in ("__id", "jsonb")] - - @property - def carryover(self) -> list[str]: - return list(self._carryover()) diff --git a/src/ldlite/database/_expansion/rewrite.py b/src/ldlite/database/_expansion/recursive_nodes.py similarity index 58% rename from src/ldlite/database/_expansion/rewrite.py rename to src/ldlite/database/_expansion/recursive_nodes.py index 30912e0..2e6c76f 100644 --- a/src/ldlite/database/_expansion/rewrite.py +++ b/src/ldlite/database/_expansion/recursive_nodes.py @@ -2,210 +2,17 @@ from abc import ABC, abstractmethod from collections import deque -from typing import TYPE_CHECKING, Literal, TypeVar, cast +from typing import TYPE_CHECKING, TypeVar, cast from uuid import uuid4 -import duckdb -import psycopg from psycopg import sql if TYPE_CHECKING: from collections.abc import Callable, Iterator - from typing import NoReturn, TypeAlias - from tqdm import tqdm - -Conn: TypeAlias = duckdb.DuckDBPyConnection | psycopg.Connection -JsonType: TypeAlias = Literal["array", "object", "string", "number", "boolean", "jsonb"] - - -class Node: - def __init__(self, source: sql.Identifier, prop: str | None): - self.source = source - self.prop = prop - self.snake_prop: str | None = None - - if self.prop is not None: - self.snake_prop = "".join( - "_" + c.lower() if c.isupper() else c for c in self.prop - ) - - # there's also sorts of weird edge cases here that don't come up in practice - if ( - (naked := self.prop.lstrip("_")) - and len(naked) > 0 - and naked[0].isupper() - ): - self.snake_prop = self.snake_prop.removeprefix("_") - - -class FixedValueNode(Node): - def __init__( - self, - source: sql.Identifier, - prop: str | None, - path: sql.Composable, - prefix: str, - ): - super().__init__(source, prop) - - self.path = path - self.prefix = prefix - - @property - def alias(self) -> str: - if len(self.prefix) == 0: - return self.snake_prop or "" - return self.prefix + ( - ("__" + self.snake_prop) if self.snake_prop is not None else "" - ) - - @property - def stmt(self) -> sql.Composable: - # this should be abstract but Python can't use ABCs as a generic - return sql.SQL("") - - -class TypedNode(FixedValueNode): - def __init__( - self, - source: sql.Identifier, - prop: str | None, - path: sql.Composable, - prefix: str, - json_types: tuple[JsonType, JsonType], - ): - super().__init__(source, prop, path, prefix) - - self.is_mixed = json_types[0] != json_types[1] - self.json_type: JsonType = "string" if self.is_mixed else json_types[0] - self.is_uuid = False - self.is_datetime = False - self.is_float = False - self.is_bigint = False - - @property - def json_string(self) -> sql.Composable: - if self.prop is None: - str_extract = ( - sql.SQL("""TRIM(BOTH '"' FROM (""") + self.path + sql.SQL(")::text)") - ) - else: - str_extract = self.path + sql.SQL("->>") + sql.Literal(self.prop) - - return sql.SQL("NULLIF(NULLIF(") + str_extract + sql.SQL(", ''), 'null')") - - @property - def stmt(self) -> sql.Composable: - type_extract: sql.Composable - if self.json_type == "jsonb": - type_extract = self.path - elif self.json_type == "number" and self.is_float: - type_extract = self.json_string + sql.SQL("::numeric") - elif self.json_type == "number" and self.is_bigint: - type_extract = self.json_string + sql.SQL("::bigint") - elif self.json_type == "number": - type_extract = self.json_string + sql.SQL("::integer") - elif self.json_type == "boolean": - type_extract = self.json_string + sql.SQL("::bool") - elif self.json_type == "string" and self.is_uuid: - type_extract = self.json_string + sql.SQL("::uuid") - elif self.json_type == "string" and self.is_datetime: - type_extract = self.json_string + sql.SQL("::timestamptz") - else: - type_extract = self.json_string - - return type_extract + sql.SQL(" AS ") + sql.Identifier(self.alias) - - def specify_type(self, conn: Conn) -> None: - if self.is_mixed or self.json_type not in ["string", "number"]: - return - - cte = ( - sql.SQL(""" -WITH string_values AS MATERIALIZED ( - SELECT """) - + self.json_string - + sql.SQL(""" AS string_value - FROM {source} -)""").format(source=self.source) - ) - - if self.json_type == "string": - with conn.cursor() as cur: - specify = cte + sql.SQL(""" -SELECT - NOT EXISTS( - SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - string_value NOT LIKE '________-____-____-____-____________' - ) AS is_uuid - ,NOT EXISTS( - SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - ( - string_value NOT LIKE '____-__-__T__:__:__.___' AND - string_value NOT LIKE '____-__-__T__:__:__.___+__:__' - ) - ) AS is_uuid;""") - - cur.execute(specify.as_string()) - if row := cur.fetchone(): - (self.is_uuid, self.is_datetime) = row - - if self.json_type == "number": - with conn.cursor() as cur: - specify = cte + sql.SQL(""" -SELECT - EXISTS( - SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - string_value::numeric % 1 <> 0 - ) AS is_float - ,EXISTS( - SELECT 1 FROM string_values - WHERE - string_value IS NOT NULL AND - string_value::numeric > 2147483647 - ) AS is_bigint;""") - - cur.execute(specify.as_string()) - if row := cur.fetchone(): - (self.is_float, self.is_bigint) = row - - -class OrdinalNode(FixedValueNode): - def __init__( - self, - source: sql.Identifier, - path: sql.Composable, - prefix: str, - ): - super().__init__(source, None, path, prefix) - - @property - def alias(self) -> str: - return self.prefix + "__o" - - @property - def stmt(self) -> sql.Composable: - return sql.Identifier("a", "__o") + sql.SQL(" AS {alias}").format( - alias=sql.Identifier(self.alias), - ) - - -class JsonbNode(TypedNode): - def __init__( - self, - source: sql.Identifier, - path: sql.Composable, - prefix: str, - ): - super().__init__(source, None, path, prefix, ("jsonb", "jsonb")) +from .fixed_nodes import FixedValueNode, JsonbNode, JsonType, OrdinalNode, TypedNode +from .node import Conn, Node TNode = TypeVar("TNode", bound="Node") TRode = TypeVar("TRode", bound="RecursiveNode") @@ -575,78 +382,3 @@ def create_array_table(source_stmt: sql.SQL) -> sql.Composed: ) return (output_table_name, create_array_table) - - -def _non_srs_statements( - conn: Conn, - source_table: sql.Identifier, - output_table: Callable[[str | None], tuple[str, sql.Identifier]], - json_depth: int, - scan_progress: tqdm[NoReturn], -) -> Iterator[tuple[str, Callable[[sql.SQL], sql.Composed]]]: - # Here be dragons! The nodes have inner state manipulations - # that violate the space/time continuum: - # * o.load_columns - # * a.make_temp - # * t.specify_type - # These all are expected to be called before generating the sql - # as they load/prepare database information. - # Because building up to the transformation statements takes a long time - # we're doing all that work up front to keep the time that - # a transaction is opened to a minimum (which is a leaky abstraction). - scan_progress.total = scan_progress.total if scan_progress.total is not None else 1 - - root = RootNode(source_table, output_table) - onodes: deque[ObjectNode] = deque([root]) - while onodes: - o = onodes.popleft() - if o.depth < json_depth: - o.load_columns(conn) - scan_progress.total += len(o.direct(Node)) - else: - o.make_jsonb() - scan_progress.update(1) - - onodes.extend(o.direct(ObjectNode)) - anodes = deque(o.direct(ArrayNode)) - while anodes: - a = anodes.popleft() - if a.depth < json_depth: - if n := a.make_temp(conn): - if isinstance(n, ObjectNode): - onodes.append(n) - if isinstance(n, ArrayNode): - anodes.append(n) - scan_progress.total += 1 - else: - cast("RecursiveNode", a.parent).remove(a) - else: - a.make_jsonb() - - scan_progress.update(1) - - for t in root.typed_nodes(): - t.specify_type(conn) - scan_progress.update(1) - - yield root.create_statement - for a in root.descendents(ArrayNode): - yield a.create_statement - - -def non_srs_statements( - conn: Conn, - source_table: sql.Identifier, - output_table: Callable[[str | None], tuple[str, sql.Identifier]], - json_depth: int, - scan_progress: tqdm[NoReturn], -) -> list[tuple[str, Callable[[sql.SQL], sql.Composed]]]: - return list( - _non_srs_statements( - conn, - source_table, - output_table, - json_depth, - scan_progress, - ), - ) diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 3b106f6..4d7df17 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -11,7 +11,7 @@ from tqdm import tqdm from . import Database -from ._expansion.rewrite import non_srs_statements +from ._expansion import non_srs_statements from ._prefix import Prefix if TYPE_CHECKING: From 7d25e53fe26c10c25752b507d136c7ea2c611a65 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Thu, 26 Mar 2026 12:46:05 +0000 Subject: [PATCH 23/28] Remove delete from returning construct --- src/ldlite/database/_duckdb.py | 3 - src/ldlite/database/_expansion/__init__.py | 4 +- .../database/_expansion/recursive_nodes.py | 69 ++++++++----------- src/ldlite/database/_postgres.py | 6 -- src/ldlite/database/_typed_database.py | 13 +--- 5 files changed, 33 insertions(+), 62 deletions(-) diff --git a/src/ldlite/database/_duckdb.py b/src/ldlite/database/_duckdb.py index 1b5334c..f4268ea 100644 --- a/src/ldlite/database/_duckdb.py +++ b/src/ldlite/database/_duckdb.py @@ -86,9 +86,6 @@ def ingest_records( return total - def source_stmt(self, keep_source: bool) -> sql.SQL: # noqa: ARG002 - return sql.SQL("SELECT * FROM {source_table}") - # DuckDB has some strong opinions about cursors that are different than postgres # https://github.com/duckdb/duckdb/issues/11018 diff --git a/src/ldlite/database/_expansion/__init__.py b/src/ldlite/database/_expansion/__init__.py index e11620b..4527a58 100644 --- a/src/ldlite/database/_expansion/__init__.py +++ b/src/ldlite/database/_expansion/__init__.py @@ -21,7 +21,7 @@ def _non_srs_statements( output_table: Callable[[str | None], tuple[str, sql.Identifier]], json_depth: int, scan_progress: tqdm[NoReturn], -) -> Iterator[tuple[str, Callable[[sql.SQL], sql.Composed]]]: +) -> Iterator[tuple[str, sql.Composed]]: # Here be dragons! The nodes have inner state manipulations # that violate the space/time continuum: # * o.load_columns @@ -78,7 +78,7 @@ def non_srs_statements( output_table: Callable[[str | None], tuple[str, sql.Identifier]], json_depth: int, scan_progress: tqdm[NoReturn], -) -> list[tuple[str, Callable[[sql.SQL], sql.Composed]]]: +) -> list[tuple[str, sql.Composed]]: return list( _non_srs_statements( conn, diff --git a/src/ldlite/database/_expansion/recursive_nodes.py b/src/ldlite/database/_expansion/recursive_nodes.py index 2e6c76f..d03566b 100644 --- a/src/ldlite/database/_expansion/recursive_nodes.py +++ b/src/ldlite/database/_expansion/recursive_nodes.py @@ -200,8 +200,7 @@ def load_columns(self, conn: Conn) -> None: class StampableTable(ABC): @property @abstractmethod - # The Callable construct is necessary until DuckDB implements CTAS RETURNING - def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: ... + def create_statement(self) -> tuple[str, sql.Composed]: ... class RootNode(ObjectNode, StampableTable): @@ -219,36 +218,29 @@ def __init__( self.get_output_table = get_output_table @property - def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: + def create_statement(self) -> tuple[str, sql.Composed]: (output_table_name, output_table) = self.get_output_table(None) - def create_root_table(source_stmt: sql.SQL) -> sql.Composed: - return ( - sql.SQL(""" + return ( + output_table_name, + sql.SQL(""" CREATE TABLE {output_table} AS -WITH root_source AS (""").format(output_table=output_table) - + source_stmt.format(source_table=self.source) - + sql.SQL( - """) SELECT - """, - ) - + sql.SQL("\n ,").join( - [ - sql.Identifier("__id"), - *[t.stmt for t in self.direct(TypedNode)], - *[ - t.stmt - for o in self.descendents(ObjectNode, ArrayNode) - for t in o.direct(TypedNode) - ], + """).format(output_table=output_table) + + sql.SQL("\n ,").join( + [ + sql.Identifier("__id"), + *[t.stmt for t in self.direct(TypedNode)], + *[ + t.stmt + for o in self.descendents(ObjectNode, ArrayNode) + for t in o.direct(TypedNode) ], - ) - + sql.SQL(""" -FROM root_source""") + ], ) - - return (output_table_name, create_root_table) + + sql.SQL(""" +FROM {source_table}""").format(source_table=self.source), + ) class ArrayNode(RecursiveNode, StampableTable): @@ -332,7 +324,7 @@ def make_temp(self, conn: Conn) -> Node | None: return None @property - def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: + def create_statement(self) -> tuple[str, sql.Composed]: table_parent: RecursiveNode = self parents: list[RecursiveNode] = [] for p in self.parents: @@ -347,17 +339,15 @@ def create_statement(self) -> tuple[str, Callable[[sql.SQL], sql.Composed]]: else: (_, parent_table) = root.get_output_table(table_parent.prefix) - def create_array_table(source_stmt: sql.SQL) -> sql.Composed: - return ( - sql.SQL(""" + return ( + output_table_name, + ( + sql.SQL( + """ CREATE TABLE {output_table} AS -WITH array_source AS (""").format(output_table=output_table) - + source_stmt.format(source_table=self.temp) - + sql.SQL( - """) SELECT """, - ) + ).format(output_table=output_table) + sql.SQL("\n ,").join( [ sql.Identifier("a", "__id"), @@ -375,10 +365,9 @@ def create_array_table(source_stmt: sql.SQL) -> sql.Composed: ], ) + sql.SQL(""" -FROM array_source a +FROM {source_table} a JOIN {parent_table} p ON a.p__id = p.__id; -""").format(parent_table=parent_table) - ) - - return (output_table_name, create_array_table) +""").format(source_table=self.temp, parent_table=parent_table) + ), + ) diff --git a/src/ldlite/database/_postgres.py b/src/ldlite/database/_postgres.py index 3edcece..aebe436 100644 --- a/src/ldlite/database/_postgres.py +++ b/src/ldlite/database/_postgres.py @@ -102,9 +102,3 @@ def preprocess_source_table( column_name=sql.SQL(",").join(column_names), ), ) - - def source_stmt(self, keep_source: bool) -> sql.SQL: - if keep_source: - return sql.SQL("SELECT * FROM {source_table}") - - return sql.SQL("DELETE FROM {source_table} RETURNING *") diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 4d7df17..df148fb 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -187,12 +187,6 @@ def preprocess_source_table( column_names: list[sql.Identifier], ) -> None: ... - # TODO: Refactor this to use DELETE RETURNING when DuckDb resolves - # https://github.com/duckdb/duckdb/issues/3417 - # Only postgres supports it which is why we have an abstraction here - @abstractmethod - def source_stmt(self, keep_source: bool) -> sql.SQL: ... - def expand_prefix( self, prefix: str, @@ -236,11 +230,8 @@ def expand_prefix( with self._begin(conn): self._drop_extracted_tables(conn, pfx) with conn.cursor() as cur: - for i, (_, table) in enumerate(tables_to_create): - create_table = table( - self.source_stmt(keep_source=(i == 0 and keep_raw)), - ) - cur.execute(create_table.as_string()) + for _, table in tables_to_create: + cur.execute(table.as_string()) transform_progress.update(1) # duckdb can't drop the raw table when creating the output table From 351c9fadb433f31a2fdff2d53fed05e0e8842a83 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Thu, 26 Mar 2026 12:53:51 +0000 Subject: [PATCH 24/28] Inline the ANALYZE statements necessary for the transform --- .../database/_expansion/recursive_nodes.py | 24 +++++++++++++------ src/ldlite/database/_postgres.py | 17 ------------- src/ldlite/database/_typed_database.py | 7 ------ 3 files changed, 17 insertions(+), 31 deletions(-) diff --git a/src/ldlite/database/_expansion/recursive_nodes.py b/src/ldlite/database/_expansion/recursive_nodes.py index d03566b..e7a32ce 100644 --- a/src/ldlite/database/_expansion/recursive_nodes.py +++ b/src/ldlite/database/_expansion/recursive_nodes.py @@ -239,7 +239,11 @@ def create_statement(self) -> tuple[str, sql.Composed]: ], ) + sql.SQL(""" -FROM {source_table}""").format(source_table=self.source), +FROM {source_table}; +ANALYZE {output_table} (__id);""").format( + source_table=self.source, + output_table=output_table, + ), ) @@ -257,8 +261,8 @@ def __init__( def make_temp(self, conn: Conn) -> Node | None: with conn.cursor() as cur: expansion = ( - sql.SQL("CREATE TEMPORARY TABLE {temp} AS").format(temp=self.temp) - + sql.SQL(""" + sql.SQL(""" +CREATE TEMPORARY TABLE {temp} AS SELECT __id AS p__id ,(ROW_NUMBER() OVER (ORDER BY (SELECT NULL)))::integer AS __id @@ -273,15 +277,16 @@ def make_temp(self, conn: Conn) -> Node | None: ,a.ord FROM ( - SELECT """) + SELECT """).format(temp=self.temp) + self.path + sql.SQL(""" AS ld_value, __id FROM {source} ) j CROSS JOIN LATERAL jsonb_array_elements(j.ld_value) WITH ORDINALITY a("value", ord) WHERE jsonb_typeof(j.ld_value) = 'array' ) expansion -WHERE json_type <> 'null' -""").format(source=self.source) +WHERE json_type <> 'null'; +ANALYZE {temp} (p__id, array_jsonb, json_type); +""").format(source=self.source, temp=self.temp) ) cur.execute(expansion.as_string()) @@ -368,6 +373,11 @@ def create_statement(self) -> tuple[str, sql.Composed]: FROM {source_table} a JOIN {parent_table} p ON a.p__id = p.__id; -""").format(source_table=self.temp, parent_table=parent_table) +ANALYZE {output_table} (__id); +""").format( + source_table=self.temp, + parent_table=parent_table, + output_table=output_table, + ) ), ) diff --git a/src/ldlite/database/_postgres.py b/src/ldlite/database/_postgres.py index aebe436..7f08175 100644 --- a/src/ldlite/database/_postgres.py +++ b/src/ldlite/database/_postgres.py @@ -85,20 +85,3 @@ def ingest_records( conn.commit() return next(pkey) - 1 - - def preprocess_source_table( - self, - conn: psycopg.Connection, - table_name: sql.Identifier, - column_names: list[sql.Identifier], - ) -> None: - if len(column_names) == 0: - return - - with conn.cursor() as cur: - cur.execute( - sql.SQL("ANALYZE {table_name} ({column_name})").format( - table_name=table_name, - column_name=sql.SQL(",").join(column_names), - ), - ) diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index df148fb..64c59d1 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -180,13 +180,6 @@ def _prepare_raw_table( ).as_string(), ) - def preprocess_source_table( - self, - conn: DB, - table_name: sql.Identifier, - column_names: list[sql.Identifier], - ) -> None: ... - def expand_prefix( self, prefix: str, From 7955d2167ea26d2f68fffd5bcc6f34b7071e18fd Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Thu, 26 Mar 2026 13:05:05 +0000 Subject: [PATCH 25/28] Invert json_depth check and simplify unparenting empty arrays --- src/ldlite/database/_expansion/__init__.py | 38 +++++++++++-------- .../database/_expansion/recursive_nodes.py | 12 ++++-- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/ldlite/database/_expansion/__init__.py b/src/ldlite/database/_expansion/__init__.py index 4527a58..bb33ded 100644 --- a/src/ldlite/database/_expansion/__init__.py +++ b/src/ldlite/database/_expansion/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import deque -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -12,7 +12,7 @@ from .node import Conn, Node -from .recursive_nodes import ArrayNode, ObjectNode, RecursiveNode, RootNode +from .recursive_nodes import ArrayNode, ObjectNode, RootNode def _non_srs_statements( @@ -38,28 +38,34 @@ def _non_srs_statements( onodes: deque[ObjectNode] = deque([root]) while onodes: o = onodes.popleft() - if o.depth < json_depth: - o.load_columns(conn) - scan_progress.total += len(o.direct(Node)) - else: + + if o.depth >= json_depth: o.make_jsonb() + scan_progress.update(1) + continue + + o.load_columns(conn) + scan_progress.total += len(o.direct(Node)) scan_progress.update(1) onodes.extend(o.direct(ObjectNode)) anodes = deque(o.direct(ArrayNode)) while anodes: a = anodes.popleft() - if a.depth < json_depth: - if n := a.make_temp(conn): - if isinstance(n, ObjectNode): - onodes.append(n) - if isinstance(n, ArrayNode): - anodes.append(n) - scan_progress.total += 1 - else: - cast("RecursiveNode", a.parent).remove(a) - else: + + if a.depth >= json_depth: a.make_jsonb() + scan_progress.update(1) + continue + + if n := a.make_temp(conn): + if isinstance(n, ObjectNode): + onodes.append(n) + if isinstance(n, ArrayNode): + anodes.append(n) + scan_progress.total += 1 + else: + a.unparent() scan_progress.update(1) diff --git a/src/ldlite/database/_expansion/recursive_nodes.py b/src/ldlite/database/_expansion/recursive_nodes.py index e7a32ce..c5faf26 100644 --- a/src/ldlite/database/_expansion/recursive_nodes.py +++ b/src/ldlite/database/_expansion/recursive_nodes.py @@ -63,9 +63,16 @@ def depth(self) -> int: return depth - def replace(self, original: Node, replacement: Node) -> None: + def replace(self, original: Node, replacement: Node | None) -> None: + if replacement is None: + self._children.remove(original) + return + self._children = [(replacement if n == original else n) for n in self._children] + def unparent(self) -> None: + cast("RecursiveNode", self.parent).replace(self, None) + def make_jsonb(self) -> None: cast("RecursiveNode", self.parent).replace( self, @@ -143,9 +150,6 @@ def _typed_nodes(self) -> Iterator[TypedNode]: def typed_nodes(self) -> list[TypedNode]: return list(self._typed_nodes()) - def remove(self, node: RecursiveNode) -> None: - self._children.remove(node) - class ObjectNode(RecursiveNode): def load_columns(self, conn: Conn) -> None: From 870dc1c03e9cb1431110d3332a74c847877b130c Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Thu, 26 Mar 2026 13:30:55 +0000 Subject: [PATCH 26/28] Analyze important column in raw table for postgres --- src/ldlite/database/_postgres.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/ldlite/database/_postgres.py b/src/ldlite/database/_postgres.py index 7f08175..93ca50c 100644 --- a/src/ldlite/database/_postgres.py +++ b/src/ldlite/database/_postgres.py @@ -80,6 +80,11 @@ def ingest_records( rb.extend(r) copy.write_row((next(pkey).to_bytes(4, "big"), rb)) + with conn.cursor() as cur: + cur.execute( + sql.SQL("ANALYZE {table} (jsonb);").format(table=pfx.raw_table.id), + ) + total = next(pkey) - 1 self._download_complete(conn, pfx, total, download_started) conn.commit() From e2bec99343d225b7f3db077d6640ef0ce2884694 Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Thu, 26 Mar 2026 13:38:28 +0000 Subject: [PATCH 27/28] Re-enable flaky source records CI test --- .github/workflows/test.yaml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 410cdf6..dbbe015 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,7 +25,8 @@ jobs: - run: pdm install -G:all --lockfile pylock.toml # test - - run: pdm run test -k 'not test_srs' + #- run: pdm run test -k 'not test_srs' + - run: pdm run test test-minimal-deps: runs-on: ubuntu-latest @@ -41,7 +42,8 @@ jobs: - run: pdm install -G:all --lockfile pylock.minimal.toml # test - - run: pdm run test -k 'not test_srs' + #- run: pdm run test -k 'not test_srs' + - run: pdm run test test-maximal-deps: runs-on: ubuntu-latest @@ -57,4 +59,5 @@ jobs: - run: pdm install -G:all --lockfile pylock.maximal.toml # test - - run: pdm run test -k 'not test_srs' + #- run: pdm run test -k 'not test_srs' + - run: pdm run test From 82a7ffd0bb510832e40db3da793ceac1b9eed8ce Mon Sep 17 00:00:00 2001 From: Katherine Bargar Date: Thu, 26 Mar 2026 14:56:05 +0000 Subject: [PATCH 28/28] Cleanup transform_progress --- src/ldlite/database/_typed_database.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index 64c59d1..41b10e6 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -200,13 +200,6 @@ def expand_prefix( conn.commit() return [] - transform_progress = ( - transform_progress - if transform_progress is not None - else tqdm(disable=True, total=0) - ) - transform_progress.total = 1 - with closing(self._conn_factory(False)) as conn: tables_to_create = non_srs_statements( conn, @@ -217,7 +210,21 @@ def expand_prefix( if scan_progress is not None else tqdm(disable=True, total=0), ) - transform_progress.total += len(tables_to_create) + 1 + + transform_progress = ( + transform_progress + if transform_progress is not None + else tqdm(disable=True, total=0) + ) + transform_progress.total = ( + ( + transform_progress.total + if transform_progress.total is not None + else 0 + ) + + len(tables_to_create) + + 1 + ) transform_progress.update(1) with self._begin(conn): @@ -227,7 +234,6 @@ def expand_prefix( cur.execute(table.as_string()) transform_progress.update(1) - # duckdb can't drop the raw table when creating the output table if not keep_raw: self._drop_raw_table(conn, pfx)