11"""Asyncmy database configuration."""
22
33from typing import TYPE_CHECKING , Any , ClassVar , TypedDict , cast
4+ from weakref import WeakSet
45
56import asyncmy
67from asyncmy .cursors import Cursor , DictCursor # pyright: ignore
1617from sqlspec .utils .config_tools import normalize_connection_config
1718
1819if TYPE_CHECKING :
19- from collections .abc import Callable
20+ from collections .abc import Awaitable , Callable
2021
2122 from asyncmy .cursors import Cursor , DictCursor # pyright: ignore
2223 from asyncmy .pool import Pool # pyright: ignore
@@ -71,6 +72,9 @@ class AsyncmyDriverFeatures(TypedDict):
7172 json_deserializer: Custom JSON deserializer function.
7273 Defaults to sqlspec.utils.serializers.from_json.
7374 Use for performance (orjson) or custom decoding.
75+ on_connection_create: Async callback executed when a connection is acquired from pool.
76+ Receives the raw asyncmy connection for low-level driver configuration.
77+ Called exactly once per physical connection using WeakSet tracking.
7478 enable_events: Enable database event channel support.
7579 Defaults to True when extension_config["events"] is configured.
7680 Provides pub/sub capabilities via table-backed queue (MySQL/MariaDB have no native pub/sub).
@@ -83,6 +87,7 @@ class AsyncmyDriverFeatures(TypedDict):
8387
8488 json_serializer : NotRequired ["Callable[[Any], str]" ]
8589 json_deserializer : NotRequired ["Callable[[str], Any]" ]
90+ on_connection_create : "NotRequired[Callable[[AsyncmyConnection], Awaitable[None]]]"
8691 enable_events : NotRequired [bool ]
8792 events_backend : NotRequired [str ]
8893
@@ -101,7 +106,9 @@ async def acquire_connection(self) -> "AsyncmyConnection":
101106 self ._config .connection_instance = pool
102107 ctx = pool .acquire ()
103108 self ._ctx = ctx
104- return cast ("AsyncmyConnection" , await ctx .__aenter__ ())
109+ connection = cast ("AsyncmyConnection" , await ctx .__aenter__ ())
110+ await self ._config ._ensure_connection_initialized (connection ) # pyright: ignore[reportPrivateUsage]
111+ return connection
105112
106113 async def release_connection (self , _conn : "AsyncmyConnection" ) -> None :
107114 if self ._ctx is not None :
@@ -125,7 +132,9 @@ async def __aenter__(self) -> AsyncmyConnection:
125132 self ._config .connection_instance = pool
126133 ctx = pool .acquire ()
127134 self ._ctx = ctx
128- return cast ("AsyncmyConnection" , await ctx .__aenter__ ())
135+ connection = cast ("AsyncmyConnection" , await ctx .__aenter__ ())
136+ await self ._config ._ensure_connection_initialized (connection ) # pyright: ignore[reportPrivateUsage]
137+ return connection
129138
130139 async def __aexit__ (
131140 self , exc_type : "type[BaseException] | None" , exc_val : "BaseException | None" , exc_tb : Any
@@ -181,12 +190,20 @@ def __init__(
181190 statement_config = statement_config or default_statement_config
182191 statement_config , driver_features = apply_driver_features (statement_config , driver_features )
183192
193+ # Extract user connection hook before storing driver_features
194+ features_dict = dict (driver_features ) if driver_features else {}
195+ self ._user_connection_hook : Callable [[AsyncmyConnection ], Awaitable [None ]] | None = features_dict .pop (
196+ "on_connection_create" , None
197+ )
198+ # Track initialized connections to ensure callback runs exactly once per physical connection
199+ self ._initialized_connections : WeakSet [Any ] = WeakSet ()
200+
184201 super ().__init__ (
185202 connection_config = connection_config ,
186203 connection_instance = connection_instance ,
187204 migration_config = migration_config ,
188205 statement_config = statement_config ,
189- driver_features = driver_features ,
206+ driver_features = features_dict ,
190207 bind_key = bind_key ,
191208 extension_config = extension_config ,
192209 observability_config = observability_config ,
@@ -205,6 +222,17 @@ async def _create_pool(self) -> "AsyncmyPool":
205222 """
206223 return cast ("AsyncmyPool" , await asyncmy .create_pool (** dict (self .connection_config )))
207224
225+ async def _ensure_connection_initialized (self , connection : "AsyncmyConnection" ) -> None :
226+ """Ensure connection callback has been called exactly once for this connection.
227+
228+ Uses WeakSet tracking to ensure the callback runs once per physical connection.
229+ """
230+ if self ._user_connection_hook is None :
231+ return
232+ if connection not in self ._initialized_connections :
233+ await self ._user_connection_hook (connection )
234+ self ._initialized_connections .add (connection )
235+
208236 async def _close_pool (self ) -> None :
209237 """Close the actual async connection pool."""
210238 if self .connection_instance :
@@ -226,7 +254,9 @@ async def create_connection(self) -> AsyncmyConnection:
226254 if pool is None :
227255 pool = await self .create_pool ()
228256 self .connection_instance = pool
229- return cast ("AsyncmyConnection" , await pool .acquire ())
257+ connection = cast ("AsyncmyConnection" , await pool .acquire ())
258+ await self ._ensure_connection_initialized (connection )
259+ return connection
230260
231261 def provide_connection (self , * args : Any , ** kwargs : Any ) -> "AsyncmyConnectionContext" :
232262 """Provide an async connection context manager.
0 commit comments