Skip to content

Commit 0593663

Browse files
authored
feat: add on_connection_create hooks and fix DuckDB variable persistence (#342)
- Adds the already implemented `on_connection_create` callback support to all adapter configs, allowing custom connection initialization logic (#340) - Fixes DuckDB `SET VARIABLE` persistence across `execute()` calls (#341)
1 parent 742ca34 commit 0593663

32 files changed

Lines changed: 963 additions & 86 deletions

File tree

sqlspec/adapters/aiosqlite/config.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sqlspec.utils.logging import get_logger
2525

2626
if TYPE_CHECKING:
27-
from collections.abc import Callable
27+
from collections.abc import Awaitable, Callable
2828

2929
from sqlspec.core import StatementConfig
3030
from sqlspec.observability import ObservabilityConfig
@@ -68,6 +68,9 @@ class AiosqliteDriverFeatures(TypedDict):
6868
Defaults to sqlspec.utils.serializers.to_json.
6969
json_deserializer: Custom JSON deserializer function.
7070
Defaults to sqlspec.utils.serializers.from_json.
71+
on_connection_create: Async callback executed when a connection is created.
72+
Receives the raw aiosqlite connection for low-level driver configuration.
73+
Runs after internal setup (PRAGMA optimizations).
7174
enable_events: Enable database event channel support.
7275
Defaults to True when extension_config["events"] is configured.
7376
Provides pub/sub capabilities via table-backed queue (SQLite has no native pub/sub).
@@ -81,6 +84,7 @@ class AiosqliteDriverFeatures(TypedDict):
8184
enable_custom_adapters: NotRequired[bool]
8285
json_serializer: "NotRequired[Callable[[Any], str]]"
8386
json_deserializer: "NotRequired[Callable[[str], Any]]"
87+
on_connection_create: "NotRequired[Callable[[AiosqliteConnection], Awaitable[None]]]"
8488
enable_events: NotRequired[bool]
8589
events_backend: NotRequired[str]
8690

@@ -191,12 +195,18 @@ def __init__(
191195
statement_config = statement_config or default_statement_config
192196
statement_config, driver_features = apply_driver_features(statement_config, driver_features)
193197

198+
# Extract user connection hook before storing driver_features
199+
features_dict = dict(driver_features) if driver_features else {}
200+
self._user_connection_hook: Callable[[AiosqliteConnection], Awaitable[None]] | None = features_dict.pop(
201+
"on_connection_create", None
202+
)
203+
194204
super().__init__(
195205
connection_config=config_dict,
196206
connection_instance=connection_instance,
197207
migration_config=migration_config,
198208
statement_config=statement_config,
199-
driver_features=driver_features,
209+
driver_features=features_dict,
200210
bind_key=bind_key,
201211
extension_config=extension_config,
202212
observability_config=observability_config,
@@ -258,6 +268,7 @@ async def _create_pool(self) -> AiosqliteConnectionPool:
258268
connect_timeout=connect_timeout,
259269
idle_timeout=idle_timeout,
260270
operation_timeout=operation_timeout,
271+
on_connection_create=self._user_connection_hook,
261272
)
262273

263274
if self.driver_features.get("enable_custom_adapters", False):

sqlspec/adapters/aiosqlite/pool.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlspec.utils.uuids import uuid4
1414

1515
if TYPE_CHECKING:
16+
from collections.abc import Awaitable, Callable
1617
from types import TracebackType
1718

1819
from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection
@@ -178,6 +179,7 @@ class AiosqliteConnectionPool:
178179
"_idle_timeout",
179180
"_lock_instance",
180181
"_min_size",
182+
"_on_connection_create",
181183
"_operation_timeout",
182184
"_pool_id",
183185
"_pool_size",
@@ -195,6 +197,7 @@ def __init__(
195197
idle_timeout: float = 24 * 60 * 60,
196198
operation_timeout: float = 10.0,
197199
health_check_interval: float = 30.0,
200+
on_connection_create: "Callable[[AiosqliteConnection], Awaitable[None]] | None" = None,
198201
) -> None:
199202
"""Initialize connection pool.
200203
@@ -206,6 +209,7 @@ def __init__(
206209
idle_timeout: Maximum time a connection can remain idle
207210
operation_timeout: Maximum time for connection operations
208211
health_check_interval: Seconds of idle time before running health check
212+
on_connection_create: Async callback executed when connection is created
209213
"""
210214
self._connection_parameters = connection_parameters
211215
self._pool_size = pool_size
@@ -214,6 +218,7 @@ def __init__(
214218
self._idle_timeout = idle_timeout
215219
self._operation_timeout = operation_timeout
216220
self._health_check_interval = health_check_interval
221+
self._on_connection_create = on_connection_create
217222

218223
self._connection_registry: dict[str, AiosqlitePoolConnection] = {}
219224
self._wal_initialized = False
@@ -324,6 +329,10 @@ async def _create_connection(self) -> AiosqlitePoolConnection:
324329
await connection.execute("PRAGMA busy_timeout = 30000")
325330
await connection.commit()
326331

332+
# Call user-provided callback after internal setup
333+
if self._on_connection_create is not None:
334+
await self._on_connection_create(connection)
335+
327336
pool_connection = AiosqlitePoolConnection(connection)
328337
pool_connection.mark_as_idle()
329338

sqlspec/adapters/asyncmy/config.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Asyncmy database configuration."""
22

33
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
4+
from weakref import WeakSet
45

56
import asyncmy
67
from asyncmy.cursors import Cursor, DictCursor # pyright: ignore
@@ -16,7 +17,7 @@
1617
from sqlspec.utils.config_tools import normalize_connection_config
1718

1819
if TYPE_CHECKING:
19-
from collections.abc import Callable
20+
from collections.abc import Awaitable, Callable
2021

2122
from asyncmy.cursors import Cursor, DictCursor # pyright: ignore
2223
from asyncmy.pool import Pool # pyright: ignore
@@ -71,6 +72,9 @@ class AsyncmyDriverFeatures(TypedDict):
7172
json_deserializer: Custom JSON deserializer function.
7273
Defaults to sqlspec.utils.serializers.from_json.
7374
Use for performance (orjson) or custom decoding.
75+
on_connection_create: Async callback executed when a connection is acquired from pool.
76+
Receives the raw asyncmy connection for low-level driver configuration.
77+
Called exactly once per physical connection using WeakSet tracking.
7478
enable_events: Enable database event channel support.
7579
Defaults to True when extension_config["events"] is configured.
7680
Provides pub/sub capabilities via table-backed queue (MySQL/MariaDB have no native pub/sub).
@@ -83,6 +87,7 @@ class AsyncmyDriverFeatures(TypedDict):
8387

8488
json_serializer: NotRequired["Callable[[Any], str]"]
8589
json_deserializer: NotRequired["Callable[[str], Any]"]
90+
on_connection_create: "NotRequired[Callable[[AsyncmyConnection], Awaitable[None]]]"
8691
enable_events: NotRequired[bool]
8792
events_backend: NotRequired[str]
8893

@@ -101,7 +106,9 @@ async def acquire_connection(self) -> "AsyncmyConnection":
101106
self._config.connection_instance = pool
102107
ctx = pool.acquire()
103108
self._ctx = ctx
104-
return cast("AsyncmyConnection", await ctx.__aenter__())
109+
connection = cast("AsyncmyConnection", await ctx.__aenter__())
110+
await self._config._ensure_connection_initialized(connection) # pyright: ignore[reportPrivateUsage]
111+
return connection
105112

106113
async def release_connection(self, _conn: "AsyncmyConnection") -> None:
107114
if self._ctx is not None:
@@ -125,7 +132,9 @@ async def __aenter__(self) -> AsyncmyConnection:
125132
self._config.connection_instance = pool
126133
ctx = pool.acquire()
127134
self._ctx = ctx
128-
return cast("AsyncmyConnection", await ctx.__aenter__())
135+
connection = cast("AsyncmyConnection", await ctx.__aenter__())
136+
await self._config._ensure_connection_initialized(connection) # pyright: ignore[reportPrivateUsage]
137+
return connection
129138

130139
async def __aexit__(
131140
self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
@@ -181,12 +190,20 @@ def __init__(
181190
statement_config = statement_config or default_statement_config
182191
statement_config, driver_features = apply_driver_features(statement_config, driver_features)
183192

193+
# Extract user connection hook before storing driver_features
194+
features_dict = dict(driver_features) if driver_features else {}
195+
self._user_connection_hook: Callable[[AsyncmyConnection], Awaitable[None]] | None = features_dict.pop(
196+
"on_connection_create", None
197+
)
198+
# Track initialized connections to ensure callback runs exactly once per physical connection
199+
self._initialized_connections: WeakSet[Any] = WeakSet()
200+
184201
super().__init__(
185202
connection_config=connection_config,
186203
connection_instance=connection_instance,
187204
migration_config=migration_config,
188205
statement_config=statement_config,
189-
driver_features=driver_features,
206+
driver_features=features_dict,
190207
bind_key=bind_key,
191208
extension_config=extension_config,
192209
observability_config=observability_config,
@@ -205,6 +222,17 @@ async def _create_pool(self) -> "AsyncmyPool":
205222
"""
206223
return cast("AsyncmyPool", await asyncmy.create_pool(**dict(self.connection_config)))
207224

225+
async def _ensure_connection_initialized(self, connection: "AsyncmyConnection") -> None:
226+
"""Ensure connection callback has been called exactly once for this connection.
227+
228+
Uses WeakSet tracking to ensure the callback runs once per physical connection.
229+
"""
230+
if self._user_connection_hook is None:
231+
return
232+
if connection not in self._initialized_connections:
233+
await self._user_connection_hook(connection)
234+
self._initialized_connections.add(connection)
235+
208236
async def _close_pool(self) -> None:
209237
"""Close the actual async connection pool."""
210238
if self.connection_instance:
@@ -226,7 +254,9 @@ async def create_connection(self) -> AsyncmyConnection:
226254
if pool is None:
227255
pool = await self.create_pool()
228256
self.connection_instance = pool
229-
return cast("AsyncmyConnection", await pool.acquire())
257+
connection = cast("AsyncmyConnection", await pool.acquire())
258+
await self._ensure_connection_initialized(connection)
259+
return connection
230260

231261
def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncmyConnectionContext":
232262
"""Provide an async connection context manager.

sqlspec/adapters/asyncpg/config.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,18 @@ def __init__(
307307
statement_config = statement_config or default_statement_config
308308
statement_config, driver_features = apply_driver_features(statement_config, driver_features)
309309

310+
# Extract user connection hook before storing driver_features
311+
features_dict = dict(driver_features) if driver_features else {}
312+
self._user_connection_hook: Callable[[AsyncpgConnection], Awaitable[None]] | None = features_dict.pop(
313+
"on_connection_create", None
314+
)
315+
310316
super().__init__(
311317
connection_config=normalize_connection_config(connection_config),
312318
connection_instance=connection_instance,
313319
migration_config=migration_config,
314320
statement_config=statement_config,
315-
driver_features=driver_features,
321+
driver_features=features_dict,
316322
bind_key=bind_key,
317323
extension_config=extension_config,
318324
observability_config=observability_config,
@@ -432,7 +438,7 @@ async def _create_pool(self) -> "Pool[Record]":
432438
return await asyncpg_create_pool(**config)
433439

434440
async def _init_connection(self, connection: "AsyncpgConnection") -> None:
435-
"""Initialize connection with JSON codecs and pgvector support.
441+
"""Initialize connection with JSON codecs, pgvector support, and user callback.
436442
437443
Args:
438444
connection: AsyncPG connection to initialize.
@@ -456,6 +462,10 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None:
456462
if self._pgvector_available:
457463
await register_pgvector_support(connection)
458464

465+
# Call user-provided callback after internal setup
466+
if self._user_connection_hook is not None:
467+
await self._user_connection_hook(connection)
468+
459469
async def _close_pool(self) -> None:
460470
"""Close the actual async connection pool and cleanup connectors."""
461471
if self.connection_instance:

sqlspec/adapters/cockroach_asyncpg/config.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@ class CockroachAsyncpgPoolConfig(CockroachAsyncpgConnectionConfig):
6464

6565

6666
class CockroachAsyncpgDriverFeatures(TypedDict):
67-
"""Driver feature flags for CockroachDB AsyncPG adapter."""
67+
"""Driver feature flags for CockroachDB AsyncPG adapter.
68+
69+
on_connection_create: Async callback executed when a connection is acquired from pool.
70+
Receives the raw asyncpg connection for low-level driver configuration.
71+
Called after internal setup (JSON codecs, pgvector registration).
72+
"""
6873

6974
enable_auto_retry: NotRequired[bool]
7075
max_retries: NotRequired[int]
@@ -77,6 +82,7 @@ class CockroachAsyncpgDriverFeatures(TypedDict):
7782
json_deserializer: NotRequired["Callable[[str], Any]"]
7883
enable_json_codecs: NotRequired[bool]
7984
enable_pgvector: NotRequired[bool]
85+
on_connection_create: "NotRequired[Callable[[CockroachAsyncpgConnection], Awaitable[None]]]"
8086
enable_events: NotRequired[bool]
8187
events_backend: NotRequired[str]
8288

@@ -163,12 +169,18 @@ def __init__(
163169
driver_features.setdefault("enable_auto_retry", True)
164170
_ = CockroachAsyncpgRetryConfig.from_features(driver_features)
165171

172+
# Extract user connection hook before storing driver_features
173+
features_dict = dict(driver_features) if driver_features else {}
174+
self._user_connection_hook: Callable[[CockroachAsyncpgConnection], Awaitable[None]] | None = features_dict.pop(
175+
"on_connection_create", None
176+
)
177+
166178
super().__init__(
167179
connection_config=connection_config,
168180
connection_instance=connection_instance,
169181
migration_config=migration_config,
170182
statement_config=statement_config,
171-
driver_features=driver_features,
183+
driver_features=features_dict,
172184
bind_key=bind_key,
173185
extension_config=extension_config,
174186
observability_config=observability_config,
@@ -177,8 +189,14 @@ def __init__(
177189

178190
async def _create_pool(self) -> "CockroachAsyncpgPool":
179191
config = build_connection_config(self.connection_config)
192+
config.setdefault("init", self._init_connection)
180193
return await asyncpg_create_pool(**config)
181194

195+
async def _init_connection(self, connection: "CockroachAsyncpgConnection") -> None:
196+
"""Initialize connection with user callback if provided."""
197+
if self._user_connection_hook is not None:
198+
await self._user_connection_hook(connection)
199+
182200
async def _close_pool(self) -> None:
183201
if not self.connection_instance:
184202
return

0 commit comments

Comments
 (0)