diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 43f3859a..9eee405d 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -5,6 +5,7 @@ on: push: branches: - main + - feat/tortoise permissions: id-token: write # This is required for requesting the JWT diff --git a/aws_advanced_python_wrapper/custom_endpoint_plugin.py b/aws_advanced_python_wrapper/custom_endpoint_plugin.py index 03b672a6..8e490e02 100644 --- a/aws_advanced_python_wrapper/custom_endpoint_plugin.py +++ b/aws_advanced_python_wrapper/custom_endpoint_plugin.py @@ -169,7 +169,8 @@ def _run(self): len(endpoints), endpoint_hostnames) - sleep(self._refresh_rate_ns / 1_000_000_000) + if self._stop_event.wait(self._refresh_rate_ns / 1_000_000_000): + break continue endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0]) @@ -178,7 +179,8 @@ def _run(self): if cached_info is not None and cached_info == endpoint_info: elapsed_time = perf_counter_ns() - start_ns sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) - sleep(sleep_duration / 1_000_000_000) + if self._stop_event.wait(sleep_duration / 1_000_000_000): + break continue logger.debug( @@ -196,7 +198,8 @@ def _run(self): elapsed_time = perf_counter_ns() - start_ns sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) - sleep(sleep_duration / 1_000_000_000) + if self._stop_event.wait(sleep_duration / 1_000_000_000): + break continue except InterruptedError as e: raise e diff --git a/aws_advanced_python_wrapper/errors.py b/aws_advanced_python_wrapper/errors.py index b265c94e..7fa44f83 100644 --- a/aws_advanced_python_wrapper/errors.py +++ b/aws_advanced_python_wrapper/errors.py @@ -45,3 +45,31 @@ class FailoverSuccessError(FailoverError): class ReadWriteSplittingError(AwsWrapperError): __module__ = "aws_advanced_python_wrapper" + + +class AsyncConnectionPoolError(AwsWrapperError): + __module__ = "aws_advanced_python_wrapper" + + +class PoolNotInitializedError(AsyncConnectionPoolError): + __module__ = "aws_advanced_python_wrapper" + + +class PoolClosingError(AsyncConnectionPoolError): + __module__ = "aws_advanced_python_wrapper" + + +class PoolExhaustedError(AsyncConnectionPoolError): + __module__ = "aws_advanced_python_wrapper" + + +class ConnectionReleasedError(AsyncConnectionPoolError): + __module__ = "aws_advanced_python_wrapper" + + +class PoolSizeLimitError(AsyncConnectionPoolError): + __module__ = "aws_advanced_python_wrapper" + + +class PoolHealthCheckError(AsyncConnectionPoolError): + __module__ = "aws_advanced_python_wrapper" diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index a503be4c..da67a939 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -114,7 +114,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl is_cached_token = (token_info is not None and not token_info.is_expired()) if not self._plugin_service.is_login_exception(error=e) or not is_cached_token: - raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e + raise # Login unsuccessful with cached token # Try to generate a new token and try to connect again diff --git a/aws_advanced_python_wrapper/tortoise_orm/__init__.py b/aws_advanced_python_wrapper/tortoise_orm/__init__.py new file mode 100644 index 00000000..503c5ea2 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/__init__.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise.backends.base.config_generator import DB_LOOKUP + + +def cast_to_bool(value): + """Generic function to cast various types to boolean.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() in ('true', '1', 'yes', 'on') + return bool(value) + + +# Register AWS MySQL backend +DB_LOOKUP["aws-mysql"] = { + "engine": "aws_advanced_python_wrapper.tortoise_orm.backends.mysql", + "vmap": { + "path": "database", + "hostname": "host", + "port": "port", + "username": "user", + "password": "password", + }, + "defaults": {"port": 3306, "charset": "utf8mb4", "sql_mode": "STRICT_TRANS_TABLES"}, + "cast": { + "minsize": int, + "maxsize": int, + "connect_timeout": int, + "echo": cast_to_bool, + "use_unicode": cast_to_bool, + "ssl": cast_to_bool, + "use_pure": cast_to_bool + }, +} diff --git a/aws_advanced_python_wrapper/tortoise_orm/async_support/__init__.py b/aws_advanced_python_wrapper/tortoise_orm/async_support/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/async_support/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise_orm/async_support/async_connection_pool.py b/aws_advanced_python_wrapper/tortoise_orm/async_support/async_connection_pool.py new file mode 100644 index 00000000..5b60a25b --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/async_support/async_connection_pool.py @@ -0,0 +1,502 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generic async connection pool - manual connection management. +User controls when to acquire and release connections. +""" +import asyncio +import logging +import time +from contextlib import asynccontextmanager +from dataclasses import dataclass +from enum import Enum +from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar + +from aws_advanced_python_wrapper.errors import (ConnectionReleasedError, + PoolClosingError, + PoolExhaustedError, + PoolHealthCheckError, + PoolNotInitializedError) + +logger = logging.getLogger(__name__) + +T_con = TypeVar('T_con') + + +class ConnectionState(Enum): + IDLE = "idle" + IN_USE = "in_use" + CLOSED = "closed" + + +@dataclass +class PoolConfig: + """Pool configuration""" + min_size: int = 1 + max_size: int = 20 + overflow: int = 0 # Additional connections beyond max_size. -1 = unlimited + acquire_conn_timeout: float = 60.0 # Max time to wait for connection acquisition + max_conn_lifetime: float = 3600.0 # 1 hour + max_conn_idle_time: float = 600.0 # 10 minutes + health_check_interval: float = 30.0 + health_check_timeout: float = 5.0 # Health check timeout + pre_ping: bool = True + + +class AsyncPooledConnectionWrapper: + """Combined wrapper for pooled connections with metadata and user interface""" + + def __init__(self, connection: Any, connection_id: int, pool: 'AsyncConnectionPool'): + # Pool metadata + self.connection = connection + self.connection_id = connection_id + self.created_at = time.monotonic() + self.last_used = time.monotonic() + self.use_count = 0 + self.state = ConnectionState.IDLE + + # User interface + self._pool = pool + self._released = False + + def mark_in_use(self): + self.state = ConnectionState.IN_USE + self.last_used = time.monotonic() + self.use_count += 1 + self._released = False + + def mark_idle(self): + self.state = ConnectionState.IDLE + self.last_used = time.monotonic() + + def mark_closed(self): + self.state = ConnectionState.CLOSED + + @property + def age(self) -> float: + return time.monotonic() - self.created_at + + @property + def idle_time(self) -> float: + return time.monotonic() - self.last_used + + def is_stale(self, max_conn_lifetime: float, max_conn_idle_time: float) -> bool: + return ( + self.age > max_conn_lifetime or + (self.state == ConnectionState.IDLE and self.idle_time > max_conn_idle_time) + ) + + # User interface methods + async def release(self): + """Return connection to the pool""" + if not self._released: + self._released = True + await self._pool._return_connection(self) + + async def close(self): + """Alias for release()""" + await self.release() + + def __getattr__(self, name): + """Proxy attribute access to underlying connection""" + if self._released: + raise ConnectionReleasedError("Connection already released to pool") + return getattr(self.connection, name) + + def __del__(self): + """Warn if connection not released""" + if not self._released and self.state == ConnectionState.IN_USE: + logger.warning( + f"Connection {self.connection_id} was not released! " + f"Always call release() or use context manager." + ) + + +class AsyncConnectionPool: + """ + Generic async connection pool with manual connection management. + """ + + @staticmethod + async def _default_closer(connection: Any) -> None: + """Default connection closer that handles both sync and async close methods""" + if hasattr(connection, 'close'): + close_method = connection.close + if asyncio.iscoroutinefunction(close_method): + await close_method() + else: + close_method() + + @staticmethod + async def _default_health_check(connection: Any) -> None: + """Default health check that verifies connection is not closed""" + is_closed = await asyncio.to_thread(lambda: connection.is_closed) + if is_closed: + raise PoolHealthCheckError("Connection is closed") + + def __init__( + self, + creator: Callable[[], Awaitable[T_con]], + health_check: Optional[Callable[[T_con], Awaitable[None]]] = None, + closer: Optional[Callable[[T_con], Awaitable[None]]] = None, + config: Optional[PoolConfig] = None + ): + self._creator = creator + self._health_check = health_check or self._default_health_check + self._closer = closer or self._default_closer + self._config = config or PoolConfig() + + # Pool state - queue size accounts for overflow + self._pool: asyncio.Queue[AsyncPooledConnectionWrapper] = asyncio.Queue() + self._all_connections: Dict[int, AsyncPooledConnectionWrapper] = {} + self._connection_counter = 0 + self._max_connection_id = 1000000 # Reset after 1M connections + self._size = 0 + + # Synchronization + self._lock = asyncio.Lock() + self._id_lock = asyncio.Lock() # Separate lock for connection ID generation + self._closing = False + self._initialized = False + + # Background tasks + self._maintenance_task: Optional[asyncio.Task] = None + + async def initialize(self): + """Initialize the pool with minimum connections""" + if self._initialized: + logger.warning("Pool already initialized") + return + + logger.info(f"Initializing pool with {self._config.min_size} connections") + + try: + # Create initial connections + tasks = [ + self._create_connection() + for _ in range(self._config.min_size) + ] + connections = await asyncio.gather(*tasks) # Remove return_exceptions=True + + async with self._lock: + for conn in connections: + self._size += 1 + await self._pool.put(conn) + + # Start maintenance task + self._maintenance_task = asyncio.create_task(self._maintenance_loop()) + + self._initialized = True + logger.info(f"Pool initialized with {self._size} connections") + + except Exception as e: + logger.error(f"Failed to initialize pool: {e}") + await self.close() + raise + + async def _get_next_connection_id(self) -> int: + """Get next unique connection ID with cycling to prevent overflow""" + async with self._id_lock: + while True: + self._connection_counter += 1 + if self._connection_counter > self._max_connection_id: + self._connection_counter = 1 + + # Check if ID is in use (avoid nested lock by checking outside) + if self._connection_counter not in self._all_connections: + return self._connection_counter + + async def _create_connection(self) -> AsyncPooledConnectionWrapper: + """Create a new connection (caller manages size)""" + connection_id = await self._get_next_connection_id() + + try: + raw_conn = await self._creator() + pooled_conn = AsyncPooledConnectionWrapper(raw_conn, connection_id, self) + + async with self._lock: + self._all_connections[connection_id] = pooled_conn + + logger.debug(f"Created connection {connection_id}, pool size: {self._size}") + return pooled_conn + + except Exception as e: + logger.error(f"Failed to create connection: {e}") + raise + + async def _close_connection(self, pooled_conn: AsyncPooledConnectionWrapper): + """Close a connection""" + if pooled_conn.state == ConnectionState.CLOSED: + return + + pooled_conn.mark_closed() + + try: + # Close the underlying connection + await self._closer(pooled_conn.connection) + + # Remove from tracking + async with self._lock: + self._all_connections.pop(pooled_conn.connection_id, None) + self._size -= 1 + + except Exception as e: + logger.error(f"Error closing connection {pooled_conn.connection_id}: {e}") + + async def _validate_connection(self, pooled_conn: AsyncPooledConnectionWrapper) -> bool: + """Validate a connection""" + # Check if stale + if pooled_conn.is_stale( + self._config.max_conn_lifetime, + self._config.max_conn_idle_time + ): + logger.debug(f"Connection {pooled_conn.connection_id} is stale") + return False + + # Run health check if configured + if self._config.pre_ping and self._health_check: + try: + await asyncio.wait_for( + self._health_check(pooled_conn.connection), + timeout=self._config.health_check_timeout + ) + return True + except Exception as e: + logger.warning( + f"Health check failed for connection " + f"{pooled_conn.connection_id}: {e}" + ) + return False + + return True + + async def acquire(self) -> AsyncPooledConnectionWrapper: + """ + Acquire a connection from the pool. + YOU must call release() when done! + + Returns: + AsyncPooledConnectionWrapper: Connection with .release() method and direct attribute access + """ + if not self._initialized: + raise PoolNotInitializedError("Pool not initialized. Call await pool.initialize() first") + + if self._closing: + raise PoolClosingError("Pool is closing") + + pooled_conn = None + created_new = False + + try: + # Try to get idle connection, create new one, or wait + try: + pooled_conn = self._pool.get_nowait() + except asyncio.QueueEmpty: + # Atomic check and reserve slot + async with self._lock: + max_total = self._config.max_size + (float('inf') if self._config.overflow == -1 else self._config.overflow) + + if self._size < max_total: + self._size += 1 # Reserve slot + create_new = True + else: + create_new = False + + if create_new: + try: + connection_id = await self._get_next_connection_id() + # Create connection with timeout to prevent hanging during failover + raw_conn = await asyncio.wait_for( + self._creator(), + timeout=min(self._config.acquire_conn_timeout, 30.0) # Cap at 30s to prevent indefinite hangs + ) + pooled_conn = AsyncPooledConnectionWrapper(raw_conn, connection_id, self) + + # Add to tracking + async with self._lock: + self._all_connections[connection_id] = pooled_conn + created_new = True + except Exception: + # Failed to create, decrement reserved size + async with self._lock: + self._size -= 1 + raise + else: + pooled_conn = await asyncio.wait_for(self._pool.get(), timeout=self._config.acquire_conn_timeout) + + # Validate and recreate if needed + if not await self._validate_connection(pooled_conn): + await self._close_connection(pooled_conn) + # Recreate using same pattern as initial creation to avoid deadlock + async with self._lock: + self._size += 1 # Reserve slot + + try: + connection_id = await self._get_next_connection_id() + raw_conn = await asyncio.wait_for( + self._creator(), + timeout=min(self._config.acquire_conn_timeout, 30.0) # Cap at 30s to prevent indefinite hangs + ) + pooled_conn = AsyncPooledConnectionWrapper(raw_conn, connection_id, self) + async with self._lock: + self._all_connections[connection_id] = pooled_conn + created_new = True + except Exception: + async with self._lock: + self._size -= 1 + raise + + pooled_conn.mark_in_use() + return pooled_conn + + except asyncio.TimeoutError: + raise PoolExhaustedError( + f"Pool exhausted: timeout after {self._config.acquire_conn_timeout}s, " + f"size: {self._size}/{self._config.max_size}, " + f"idle: {self._pool.qsize()}" + ) + except Exception as e: + logger.error(f"Error acquiring connection: {e}") + # If we created a new connection and got error, close it + if created_new and pooled_conn: + await self._close_connection(pooled_conn) + raise + + async def _return_connection(self, pooled_conn: AsyncPooledConnectionWrapper): + """Return connection to pool or close if excess""" + if self._closing: + await self._close_connection(pooled_conn) + return + + # Check if we should close excess connections (lock released before close) + should_close = False + async with self._lock: + should_close = self._pool.qsize() >= self._config.max_size + + if should_close: + await self._close_connection(pooled_conn) + else: + try: + pooled_conn.mark_idle() + await self._pool.put(pooled_conn) + except Exception as e: + logger.error(f"Error returning connection: {e}") + await self._close_connection(pooled_conn) + + @asynccontextmanager + async def connection(self): + """ + Context manager for automatic connection management. + + Usage: + async with pool.connection() as conn: + result = await conn.fetchval("SELECT 1") + """ + pool_conn = await self.acquire() + try: + yield pool_conn + finally: + await pool_conn.release() + + async def _maintenance_loop(self): + """Background task to maintain pool health""" + while not self._closing: + try: + await asyncio.sleep(self._config.health_check_interval) + if self._closing: + break + + # Create connections if below minimum + needed = 0 + async with self._lock: + needed = self._config.min_size - self._size + if needed > 0: + self._size += needed # Reserve slots + + if needed > 0: + for _ in range(needed): + try: + conn = await self._create_connection() + await self._pool.put(conn) + except Exception as e: + async with self._lock: + self._size -= 1 # Release reserved slot on failure + logger.error(f"Maintenance connection creation failed: {e}") + + # Remove stale idle connections (collect under lock, close outside) + stale_conns = [] + async with self._lock: + stale_conns = [ + conn for conn in self._all_connections.values() + if conn.state == ConnectionState.IDLE and + conn.is_stale(self._config.max_conn_lifetime, self._config.max_conn_idle_time) + ] + + # Close stale connections outside the lock to avoid deadlock + for conn in stale_conns: + try: + await self._close_connection(conn) + except Exception as e: + logger.error(f"Stale connection cleanup failed: {e}") + + except Exception as e: + logger.error(f"Maintenance loop error: {e}") + + async def close(self): + """Close the pool and all connections""" + if self._closing: + return + + self._closing = True + + # Cancel maintenance task + if self._maintenance_task: + self._maintenance_task.cancel() + try: + await self._maintenance_task + except asyncio.CancelledError: + pass + + # Close all connections (collect under lock, close outside to avoid deadlock) + async with self._lock: + connections = list(self._all_connections.values()) + + await asyncio.gather( + *[self._close_connection(conn) for conn in connections], + return_exceptions=True + ) + + def get_stats(self) -> Dict[str, Any]: + """Get pool statistics""" + # Note: This is a sync method, so we can't use async lock + # Stats may be slightly inconsistent but that's acceptable for monitoring + try: + states = [conn.state for conn in self._all_connections.values()] + idle_count = states.count(ConnectionState.IDLE) + in_use_count = states.count(ConnectionState.IN_USE) + except RuntimeError: # Dictionary changed during iteration + idle_count = in_use_count = 0 + + return { + "total_size": self._size, + "idle": idle_count, + "in_use": in_use_count, + "available_in_queue": self._pool.qsize(), + "max_size": self._config.max_size, + "overflow": self._config.overflow, + "min_size": self._config.min_size, + "initialized": self._initialized, + "closing": self._closing, + } diff --git a/aws_advanced_python_wrapper/tortoise_orm/async_support/async_wrapper.py b/aws_advanced_python_wrapper/tortoise_orm/async_support/async_wrapper.py new file mode 100644 index 00000000..6973bde8 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/async_support/async_wrapper.py @@ -0,0 +1,111 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Callable + +from aws_advanced_python_wrapper import AwsWrapperConnection + + +class AwsWrapperAsyncConnector(): + """Class for creating and closing AWS wrapper connections.""" + + @staticmethod + async def connect_with_aws_wrapper(connect_func: Callable, **kwargs) -> AwsConnectionAsyncWrapper: + """Create an AWS wrapper connection with async cursor support.""" + connection = await asyncio.to_thread( + AwsWrapperConnection.connect, connect_func, **kwargs + ) + return AwsConnectionAsyncWrapper(connection) + + @staticmethod + async def close_aws_wrapper(connection: AwsWrapperConnection) -> None: + """Close an AWS wrapper connection asynchronously.""" + await asyncio.to_thread(connection.close) + + +class AwsConnectionAsyncWrapper(AwsWrapperConnection): + """Wraps sync AwsConnection with async cursor support.""" + + def __init__(self, connection: AwsWrapperConnection): + self._wrapped_connection = connection + + @asynccontextmanager + async def cursor(self): + """Create an async cursor context manager.""" + cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor) + try: + yield AwsCursorAsyncWrapper(cursor_obj) + finally: + await asyncio.to_thread(cursor_obj.close) + + async def rollback(self): + """Rollback the current transaction.""" + return await asyncio.to_thread(self._wrapped_connection.rollback) + + async def commit(self): + """Commit the current transaction.""" + return await asyncio.to_thread(self._wrapped_connection.commit) + + async def set_autocommit(self, value: bool): + """Set autocommit mode.""" + return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value) + + async def close(self): + """Close the connection asynchronously.""" + return await asyncio.to_thread(self._wrapped_connection.close) + + def __getattr__(self, name): + """Delegate all other attributes/methods to the wrapped connection.""" + return getattr(self._wrapped_connection, name) + + def __del__(self): + """Delegate cleanup to wrapped connection.""" + if hasattr(self, '_wrapped_connection'): + # Let the wrapped connection handle its own cleanup + pass + + +class AwsCursorAsyncWrapper: + """Wraps sync AwsCursor cursor with async support.""" + + def __init__(self, sync_cursor): + self._cursor = sync_cursor + + async def execute(self, query, params=None): + """Execute a query asynchronously.""" + return await asyncio.to_thread(self._cursor.execute, query, params) + + async def executemany(self, query, params_list): + """Execute multiple queries asynchronously.""" + return await asyncio.to_thread(self._cursor.executemany, query, params_list) + + async def fetchall(self): + """Fetch all results asynchronously.""" + return await asyncio.to_thread(self._cursor.fetchall) + + async def fetchone(self): + """Fetch one result asynchronously.""" + return await asyncio.to_thread(self._cursor.fetchone) + + async def close(self): + """Close cursor asynchronously.""" + return await asyncio.to_thread(self._cursor.close) + + def __getattr__(self, name): + """Delegate non-async attributes to the wrapped cursor.""" + return getattr(self._cursor, name) diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/__init__.py b/aws_advanced_python_wrapper/tortoise_orm/backends/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/base/__init__.py b/aws_advanced_python_wrapper/tortoise_orm/backends/base/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/base/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/base/client.py b/aws_advanced_python_wrapper/tortoise_orm/backends/base/client.py new file mode 100644 index 00000000..0a4a8c23 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/base/client.py @@ -0,0 +1,121 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Generic, cast + +from tortoise.backends.base.client import (BaseDBAsyncClient, T_conn, + TransactionalDBClient, + TransactionContext) +from tortoise.connection import connections +from tortoise.exceptions import TransactionManagementError + +if TYPE_CHECKING: + from asyncio import Lock + + from aws_advanced_python_wrapper.tortoise_orm.async_support.async_connection_pool import \ + AsyncPooledConnectionWrapper + + +class AwsBaseDBAsyncClient(BaseDBAsyncClient): + _template: Dict[str, Any] + + +class AwsTransactionalDBClient(TransactionalDBClient): + _template: Dict[str, Any] + _parent: AwsBaseDBAsyncClient + pass + + +class TortoiseAwsClientPooledConnectionWrapper(Generic[T_conn]): + """Manages acquiring from and releasing connections to a pool.""" + + __slots__ = ("client", "connection", "_pool_init_lock",) + + def __init__( + self, + client: BaseDBAsyncClient, + pool_init_lock: Lock, + ) -> None: + self.client = client + self.connection: AsyncPooledConnectionWrapper | None = None + self._pool_init_lock = pool_init_lock + + async def ensure_connection(self) -> None: + """Ensure the connection pool is initialized.""" + if not self.client._pool: + async with self._pool_init_lock: + if not self.client._pool: + await self.client.create_connection(with_db=True) + + async def __aenter__(self) -> AsyncPooledConnectionWrapper: + """Acquire connection from pool.""" + await self.ensure_connection() + self.connection = await self.client._pool.acquire() + return cast('AsyncPooledConnectionWrapper', self.connection) + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Close connection and release back to pool.""" + if self.connection: + await self.connection.release() + + +class TortoiseAwsClientPooledTransactionContext(TransactionContext): + """Transaction context that uses a pool to acquire connections.""" + + __slots__ = ("client", "connection_name", "token", "_pool_init_lock", "connection") + + def __init__(self, client: TransactionalDBClient, pool_init_lock: Lock) -> None: + self.client = client + self.connection_name = client.connection_name + self._pool_init_lock = pool_init_lock + self.connection = None + + async def ensure_connection(self) -> None: + """Ensure the connection pool is initialized.""" + if not self.client._parent._pool: + # a safeguard against multiple concurrent tasks trying to initialize the pool + async with self._pool_init_lock: + if not self.client._parent._pool: + await self.client._parent.create_connection(with_db=True) + + async def __aenter__(self) -> TransactionalDBClient: + """Enter transaction context.""" + await self.ensure_connection() + + # Set the context variable so the current task sees a TransactionWrapper connection + self.token = connections.set(self.connection_name, self.client) + + # Create connection and begin transaction + self.connection = await self.client._parent._pool.acquire() + self.client._connection = self.connection + await self.client.begin() + return self.client + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit transaction context with proper cleanup.""" + try: + if not self.client._finalized: + if exc_type: + # Can't rollback a transaction that already failed + if exc_type is not TransactionManagementError: + await self.client.rollback() + else: + await self.client.commit() + finally: + if self.client._connection: + await self.client._connection.release() + # self.client._connection = None + connections.reset(self.token) diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/__init__.py b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/__init__.py new file mode 100644 index 00000000..1092dd31 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .client import AwsMySQLClient + +client_class = AwsMySQLClient diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/client.py b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/client.py new file mode 100644 index 00000000..d396901a --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/client.py @@ -0,0 +1,379 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import fields +from functools import wraps +from itertools import count +from typing import (Any, Callable, Coroutine, Dict, List, Optional, + SupportsInt, Tuple, TypeVar) + +import mysql.connector +import sqlparse # type: ignore[import-untyped] +from mysql.connector import errors +from mysql.connector.charsets import MYSQL_CHARACTER_SETS +from pypika_tortoise import MySQLQuery +from tortoise.backends.base.client import (Capabilities, ConnectionWrapper, + NestedTransactionContext, + TransactionContext) +from tortoise.exceptions import (DBConnectionError, IntegrityError, + OperationalError, TransactionManagementError) + +from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError +from aws_advanced_python_wrapper.tortoise_orm.async_support.async_connection_pool import ( + AsyncConnectionPool, PoolConfig) +from aws_advanced_python_wrapper.tortoise_orm.async_support.async_wrapper import ( + AwsConnectionAsyncWrapper, AwsWrapperAsyncConnector) +from aws_advanced_python_wrapper.tortoise_orm.backends.base.client import ( + AwsBaseDBAsyncClient, AwsTransactionalDBClient, + TortoiseAwsClientPooledConnectionWrapper, + TortoiseAwsClientPooledTransactionContext) +from aws_advanced_python_wrapper.tortoise_orm.backends.mysql.executor import \ + AwsMySQLExecutor +from aws_advanced_python_wrapper.tortoise_orm.backends.mysql.schema_generator import \ + AwsMySQLSchemaGenerator +from aws_advanced_python_wrapper.utils.log import Logger + +logger = Logger(__name__) +T = TypeVar("T") +FuncType = Callable[..., Coroutine[None, None, T]] + + +def translate_exceptions(func: FuncType) -> FuncType: + """Decorator to translate MySQL connector exceptions to Tortoise exceptions.""" + @wraps(func) + async def translate_exceptions_(self, *args) -> T: + try: + try: + return await func(self, *args) + except AwsWrapperError as aws_err: # Unwrap any AwsWrappedErrors + if aws_err.__cause__: + raise aws_err.__cause__ + raise + except FailoverError: # Raise any failover errors + raise + except errors.IntegrityError as exc: + raise IntegrityError(exc) + except ( + errors.OperationalError, + errors.ProgrammingError, + errors.DataError, + errors.InternalError, + errors.NotSupportedError, + errors.DatabaseError + ) as exc: + raise OperationalError(exc) + + return translate_exceptions_ + + +class AwsMySQLClient(AwsBaseDBAsyncClient): + """AWS Advanced Python Wrapper MySQL client for Tortoise ORM.""" + query_class = MySQLQuery + executor_class = AwsMySQLExecutor + schema_generator = AwsMySQLSchemaGenerator + capabilities = Capabilities( + dialect="mysql", + requires_limit=True, + inline_comment=True, + support_index_hint=True, + support_for_posix_regex_queries=True, + support_json_attributes=True, + ) + + def __init__( + self, + *, + user: str, + password: str, + database: str, + host: str, + port: SupportsInt, + **kwargs, + ): + """Initialize AWS MySQL client with connection parameters.""" + super().__init__(**kwargs) + + # Basic connection parameters + self.user = user + self.password = password + self.database = database + self.host = host + self.port = int(port) + self.extra = kwargs.copy() + + # Extract MySQL-specific settings + self.storage_engine = self.extra.pop("storage_engine", "innodb") + self.charset = self.extra.pop("charset", "utf8mb4") + + # Remove Tortoise-specific parameters + self.extra.pop("connection_name", None) + self.extra.pop("fetch_inserted", None) + self.extra.pop("autocommit", None) + self.extra.setdefault("sql_mode", "STRICT_TRANS_TABLES") + + # Initialize connection templates + self._init_connection_templates() + + # Initialize state + self._template: Dict[str, Any] = {} + self._connection = None + self._pool_init_lock: asyncio.Lock = asyncio.Lock() + self._pool: Optional[AsyncConnectionPool] = None + + # Pool configuration + default_pool_config = {field.name: field.default for field in fields(PoolConfig)} + self._pool_config = PoolConfig( + min_size=self.extra.pop("min_size", default_pool_config["min_size"]), + max_size=self.extra.pop("max_size", default_pool_config["max_size"]), + acquire_conn_timeout=self.extra.pop("acquire_conn_timeout", default_pool_config["acquire_conn_timeout"]), + max_conn_lifetime=self.extra.pop("max_conn_lifetime", default_pool_config["max_conn_lifetime"]), + max_conn_idle_time=self.extra.pop("max_conn_idle_time", default_pool_config["max_conn_idle_time"]), + health_check_interval=self.extra.pop("health_check_interval", default_pool_config["health_check_interval"]), + pre_ping=self.extra.pop("pre_ping", default_pool_config["pre_ping"]) + ) + + def _init_connection_templates(self) -> None: + """Initialize connection templates for with/without database.""" + base_template = { + "user": self.user, + "password": self.password, + "host": self.host, + "port": self.port, + "autocommit": True, + **self.extra + } + + self._template_with_db = {**base_template, "database": self.database} + self._template_no_db = {**base_template, "database": None} + + async def _init_pool(self) -> None: + """Initialize the connection pool.""" + if self._pool is not None: + return + + async def create_connection(): + return await AwsWrapperAsyncConnector.connect_with_aws_wrapper(mysql.connector.Connect, **self._template) + + self._pool = AsyncConnectionPool( + creator=create_connection, + config=self._pool_config + ) + await self._pool.initialize() + + # Connection Management + async def create_connection(self, with_db: bool) -> None: + """Initialize connection pool and configure database settings.""" + # Validate charset + if self.charset.lower() not in [cs[0] for cs in MYSQL_CHARACTER_SETS if cs is not None]: + raise DBConnectionError(f"Unknown character set: {self.charset}") + + # Set transaction support based on storage engine + if self.storage_engine.lower() != "innodb": + self.capabilities.__dict__["supports_transactions"] = False + + # Set template based on database requirement + self._template = self._template_with_db if with_db else self._template_no_db + + await self._init_pool() + + def _disable_pool_for_testing(self) -> None: + """Disable pool initialization for unit testing.""" + self._pool = None + + async def _no_op(): + pass + self._init_pool = _no_op # type: ignore[method-assign] + + async def close(self) -> None: + """Close connections - AWS wrapper handles cleanup internally.""" + if hasattr(self, '_pool') and self._pool: + await self._pool.close() + self._pool = None + + def acquire_connection(self): + """Acquire a connection from the pool.""" + return TortoiseAwsClientPooledConnectionWrapper( + self, pool_init_lock=self._pool_init_lock + ) + + # Database Operations + async def db_create(self) -> None: + """Create the database.""" + await self.create_connection(with_db=False) + await self.execute_script(f"CREATE DATABASE {self.database};") + await self.close() + + async def db_delete(self) -> None: + """Delete the database.""" + await self.create_connection(with_db=False) + await self.execute_script(f"DROP DATABASE {self.database};") + await self.close() + + # Query Execution Methods + @translate_exceptions + async def execute_insert(self, query: str, values: List[Any]) -> int: + """Execute an INSERT query and return the last inserted row ID.""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + await cursor.execute(query, values) + return cursor.lastrowid + + @translate_exceptions + async def execute_many(self, query: str, values: List[List[Any]]) -> None: + """Execute a query with multiple parameter sets.""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + if self.capabilities.supports_transactions: + await self._execute_many_with_transaction(cursor, connection, query, values) + else: + await cursor.executemany(query, values) + + async def _execute_many_with_transaction(self, cursor: Any, connection: Any, query: str, values: List[List[Any]]) -> None: + """Execute many queries within a transaction.""" + try: + await connection.set_autocommit(False) + try: + await cursor.executemany(query, values) + except Exception: + await connection.rollback() + raise + else: + await connection.commit() + finally: + await connection.set_autocommit(True) + + @translate_exceptions + async def execute_query(self, query: str, values: Optional[List[Any]] = None) -> Tuple[int, List[Dict[str, Any]]]: + """Execute a query and return row count and results.""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + await cursor.execute(query, values) + rows = await cursor.fetchall() + if rows: + fields = [desc[0] for desc in cursor.description] + return cursor.rowcount, [dict(zip(fields, row)) for row in rows] + return cursor.rowcount, [] + + async def execute_query_dict(self, query: str, values: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + """Execute a query and return only the results as dictionaries.""" + return (await self.execute_query(query, values))[1] + + @translate_exceptions + async def execute_script(self, query: str) -> None: + """Execute a script query.""" + async with self.acquire_connection() as connection: + logger.debug(f"Executing script: {query}") + async with connection.cursor() as cursor: + # Parse multi-statement queries since MySQL Connector doesn't handle them well + statements = sqlparse.split(query) + for statement in statements: + statement = statement.strip() + if statement: + await cursor.execute(statement) + + # Transaction Support + def _in_transaction(self) -> TransactionContext: + """Create a new transaction context.""" + return TortoiseAwsClientPooledTransactionContext(TransactionWrapper(self), self._pool_init_lock) + + +class TransactionWrapper(AwsMySQLClient, AwsTransactionalDBClient): + """Transaction wrapper for AWS MySQL client.""" + + def __init__(self, connection: AwsMySQLClient) -> None: + self.connection_name = connection.connection_name + self._connection: AwsConnectionAsyncWrapper = connection._connection + + self._lock = asyncio.Lock() + self._savepoint: Optional[str] = None + self._finalized: bool = False + self._parent = connection + + def _in_transaction(self) -> TransactionContext: + """Create a nested transaction context.""" + return NestedTransactionContext(TransactionWrapper(self)) + + def acquire_connection(self): + """Acquire the transaction connection.""" + return ConnectionWrapper(self._lock, self) + + # Transaction Control Methods + @translate_exceptions + async def begin(self) -> None: + """Begin the transaction.""" + await self._connection.set_autocommit(False) + self._finalized = False + + async def commit(self) -> None: + """Commit the transaction.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + await self._connection.commit() + await self._connection.set_autocommit(True) + self._finalized = True + + async def rollback(self) -> None: + """Rollback the transaction.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + await self._connection.rollback() + await self._connection.set_autocommit(True) + self._finalized = True + + # Savepoint Management + @translate_exceptions + async def savepoint(self) -> None: + """Create a savepoint.""" + self._savepoint = _gen_savepoint_name() + async with self._connection.cursor() as cursor: + await cursor.execute(f"SAVEPOINT {self._savepoint}") + + async def savepoint_rollback(self): + """Rollback to the savepoint.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + async with self._connection.cursor() as cursor: + await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint}") + self._savepoint = None + self._finalized = True + + async def release_savepoint(self): + """Release the savepoint.""" + if self._finalized: + raise TransactionManagementError("Transaction already finalized") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to release") + async with self._connection.cursor() as cursor: + await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint}") + self._savepoint = None + self._finalized = True + + @translate_exceptions + async def execute_many(self, query: str, values: List[List[Any]]) -> None: + """Execute many queries without autocommit handling (already in transaction).""" + async with self.acquire_connection() as connection: + logger.debug(f"{query}: {values}") + async with connection.cursor() as cursor: + await cursor.executemany(query, values) + + +def _gen_savepoint_name(_c: count = count()) -> str: + """Generate a unique savepoint name.""" + return f"tortoise_savepoint_{next(_c)}" diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/executor.py b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/executor.py new file mode 100644 index 00000000..b5762337 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/executor.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from aws_advanced_python_wrapper.tortoise_orm.utils import load_mysql_module + +MySQLExecutor: Type = load_mysql_module("executor.py", "MySQLExecutor") + + +class AwsMySQLExecutor(MySQLExecutor): + """AWS MySQL Executor for Tortoise ORM.""" + pass diff --git a/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/schema_generator.py b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/schema_generator.py new file mode 100644 index 00000000..df2386bc --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/backends/mysql/schema_generator.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from aws_advanced_python_wrapper.tortoise_orm.utils import load_mysql_module + +MySQLSchemaGenerator: Type = load_mysql_module("schema_generator.py", "MySQLSchemaGenerator") + + +class AwsMySQLSchemaGenerator(MySQLSchemaGenerator): + """AWS MySQL Executor for Tortoise ORM.""" + pass diff --git a/aws_advanced_python_wrapper/tortoise_orm/utils.py b/aws_advanced_python_wrapper/tortoise_orm/utils.py new file mode 100644 index 00000000..90ecede0 --- /dev/null +++ b/aws_advanced_python_wrapper/tortoise_orm/utils.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib.util +import sys +from pathlib import Path +from typing import Type + + +def load_mysql_module(module_file: str, class_name: str) -> Type: + """Load MySQL backend module without __init__.py.""" + import tortoise + tortoise_path = Path(tortoise.__file__).parent + module_path = tortoise_path / "backends" / "mysql" / module_file + module_name = f"tortoise.backends.mysql.{module_file[:-3]}" + + if module_name not in sys.modules: + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module {module_name}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + return getattr(sys.modules[module_name], class_name) diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index e04d293f..beafa307 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -265,6 +265,13 @@ def rowcount(self) -> int: def arraysize(self) -> int: return self.target_cursor.arraysize + # Optional for PEP249 + @property + def lastrowid(self) -> int: + if hasattr(self.target_cursor, 'lastrowid'): + return self.target_cursor.lastrowid + raise AttributeError("'Cursor' object has no attribute 'lastrowid'") + def close(self) -> None: self._plugin_manager.execute(self.target_cursor, "Cursor.close", lambda: self.target_cursor.close()) diff --git a/poetry.lock b/poetry.lock index e589219b..ff600b2a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,21 @@ # This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +[[package]] +name = "aiosqlite" +version = "0.22.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiosqlite-0.22.0-py3-none-any.whl", hash = "sha256:96007fac2ce70eda3ca1bba7a3008c435258a592b8fbf2ee3eeaa36d33971a09"}, + {file = "aiosqlite-0.22.0.tar.gz", hash = "sha256:7e9e52d72b319fcdeac727668975056c49720c995176dc57370935e5ba162bb9"}, +] + +[package.extras] +dev = ["attribution (==1.8.0)", "black (==25.11.0)", "build (>=1.2)", "coverage[toml] (==7.10.7)", "flake8 (==7.3.0)", "flake8-bugbear (==24.12.12)", "flit (==3.12.0)", "mypy (==1.19.0)", "ufmt (==2.8.0)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.2)"] + [[package]] name = "aws-xray-sdk" version = "2.15.0" @@ -693,6 +709,18 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "iso8601" +version = "2.1.0" +description = "Simple module to parse ISO 8601 dates" +optional = false +python-versions = ">=3.7,<4.0" +groups = ["main"] +files = [ + {file = "iso8601-2.1.0-py3-none-any.whl", hash = "sha256:aac4145c4dcb66ad8b648a02830f5e2ff6c24af20f4f482689be402db2429242"}, + {file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"}, +] + [[package]] name = "isort" version = "5.13.2" @@ -1290,6 +1318,18 @@ files = [ {file = "pyflakes-3.1.0.tar.gz", hash = "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc"}, ] +[[package]] +name = "pypika-tortoise" +version = "0.6.2" +description = "Forked from pypika and streamline just for tortoise-orm" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pypika_tortoise-0.6.2-py3-none-any.whl", hash = "sha256:425462b02ede0a5ed7b812ec12427419927ed6b19282c55667d1cbc9a440d3cb"}, + {file = "pypika_tortoise-0.6.2.tar.gz", hash = "sha256:f95ab59d9b6454db2e8daa0934728458350a1f3d56e81d9d1debc8eebeff26b3"}, +] + [[package]] name = "pytest" version = "7.4.4" @@ -1313,6 +1353,25 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +groups = ["test"] +files = [ + {file = "pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b"}, + {file = "pytest_asyncio-0.21.2.tar.gz", hash = "sha256:d67738fc232b94b326b9d060750beb16e0074210b98dd8b58a5239fa2a154f45"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-html" version = "4.1.1" @@ -1401,6 +1460,18 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytz" +version = "2025.2" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, + {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, +] + [[package]] name = "requests" version = "2.32.4" @@ -1573,6 +1644,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlparse" +version = "0.4.4" +description = "A non-validating SQL parser." +optional = false +python-versions = ">=3.5" +groups = ["main"] +files = [ + {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"}, + {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"}, +] + +[package.extras] +dev = ["build", "flake8"] +doc = ["sphinx"] +test = ["pytest", "pytest-cov"] + [[package]] name = "tabulate" version = "0.9.0" @@ -1613,6 +1701,32 @@ files = [ {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] +[[package]] +name = "tortoise-orm" +version = "0.25.1" +description = "Easy async ORM for python, built with relations in mind" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "tortoise_orm-0.25.1-py3-none-any.whl", hash = "sha256:df0ef7e06eb0650a7e5074399a51ee6e532043308c612db2cac3882486a3fd9f"}, + {file = "tortoise_orm-0.25.1.tar.gz", hash = "sha256:4d5bfd13d5750935ffe636a6b25597c5c8f51c47e5b72d7509d712eda1a239fe"}, +] + +[package.dependencies] +aiosqlite = ">=0.16.0,<1.0.0" +iso8601 = {version = ">=2.1.0,<3.0.0", markers = "python_version < \"4.0\""} +pypika-tortoise = {version = ">=0.6.1,<1.0.0", markers = "python_version < \"4.0\""} +pytz = "*" + +[package.extras] +accel = ["ciso8601 ; sys_platform != \"win32\" and implementation_name == \"cpython\"", "orjson", "uvloop ; sys_platform != \"win32\" and implementation_name == \"cpython\""] +aiomysql = ["aiomysql"] +asyncmy = ["asyncmy (>=0.2.8,<1.0.0) ; python_version < \"4.0\""] +asyncodbc = ["asyncodbc (>=0.1.1,<1.0.0) ; python_version < \"4.0\""] +asyncpg = ["asyncpg"] +psycopg = ["psycopg[binary,pool] (>=3.0.12,<4.0.0)"] + [[package]] name = "toxiproxy-python" version = "0.1.1" @@ -2240,4 +2354,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.10.0" -content-hash = "219c6ed169c74778a600c5881dc9b14885a8b15c041e76f4f557f6538f7302d3" +content-hash = "1938c9deeacfb1c241e94260edc2f827afc4db16dcdfce392dbd2480bd580481" diff --git a/pyproject.toml b/pyproject.toml index 6b5e15f1..24503048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ types_aws_xray_sdk = "^2.13.0" opentelemetry-api = "^1.22.0" opentelemetry-sdk = "^1.22.0" requests = "^2.32.2" +tortoise-orm = "^0.25.1" +sqlparse = "^0.4.4" [tool.poetry.group.dev.dependencies] mypy = "^1.9.0" @@ -63,6 +65,7 @@ mysql-connector-python = "^9.5.0" opentelemetry-exporter-otlp = "^1.22.0" opentelemetry-exporter-otlp-proto-grpc = "^1.22.0" opentelemetry-sdk-extension-aws = "^2.0.1" +pytest-asyncio = "^0.21.0" [tool.isort] sections = "FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER" @@ -77,4 +80,3 @@ filterwarnings = [ 'ignore:cache could not write path', 'ignore:could not create cache path' ] - diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 2e23eeba..ed3dd8ff 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -62,14 +62,17 @@ def conn_utils(): def pytest_runtest_setup(item): test_name: Optional[str] = None + full_test_name = item.nodeid # Full test path including class and method + if hasattr(item, "callspec"): current_driver = item.callspec.params.get("test_driver") TestEnvironment.get_current().set_current_driver(current_driver) - test_name = item.callspec.id + test_name = f"{item.name}[{item.callspec.id}]" else: TestEnvironment.get_current().set_current_driver(None) + test_name = item.name - logger.info("Starting test preparation for: " + test_name) + logger.info(f"Starting test preparation for: {test_name} (full: {full_test_name})") segment: Optional[Segment] = None if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): @@ -80,6 +83,7 @@ def pytest_runtest_setup(item): .get_info().get_request().get_target_python_version().name) if test_name is not None: segment.put_annotation("test_name", test_name) + segment.put_annotation("full_test_name", full_test_name) info = TestEnvironment.get_current().get_info() request = info.get_request() @@ -107,7 +111,11 @@ def pytest_runtest_setup(item): logger.warning("conftest.ExceptionWhileObtainingInstanceIDs", ex) instances = list() - sleep(5) + # Only sleep if we still need to retry + if (len(instances) < request.get_num_of_instances() + or len(instances) == 0 + or not rds_utility.is_db_instance_writer(instances[0])): + sleep(5) assert len(instances) > 0 current_writer = instances[0] diff --git a/tests/integration/container/test_iam_authentication.py b/tests/integration/container/test_iam_authentication.py index 4bc6ee3d..f0bfa8ec 100644 --- a/tests/integration/container/test_iam_authentication.py +++ b/tests/integration/container/test_iam_authentication.py @@ -68,7 +68,7 @@ def test_iam_wrong_database_username(self, test_environment: TestEnvironment, params = conn_utils.get_connect_params(user=user) params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver - with pytest.raises(AwsWrapperError): + with pytest.raises(Exception): AwsWrapperConnection.connect( target_driver_connect, **params, diff --git a/tests/integration/container/tortoise/__init__.py b/tests/integration/container/tortoise/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/tests/integration/container/tortoise/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/container/tortoise/models/__init__.py b/tests/integration/container/tortoise/models/__init__.py new file mode 100644 index 00000000..12de791f --- /dev/null +++ b/tests/integration/container/tortoise/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/tests/integration/container/tortoise/models/test_models.py b/tests/integration/container/tortoise/models/test_models.py new file mode 100644 index 00000000..b7d3ffc5 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "users" + + +class UniqueName(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=20, null=True, unique=True) + optional = fields.CharField(max_length=20, null=True) + other_optional = fields.CharField(max_length=20, null=True) + + class Meta: + table = "unique_names" + + +class TableWithSleepTrigger(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + value = fields.CharField(max_length=100) + + class Meta: + table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_copy.py b/tests/integration/container/tortoise/models/test_models_copy.py new file mode 100644 index 00000000..b7d3ffc5 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models_copy.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "users" + + +class UniqueName(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=20, null=True, unique=True) + optional = fields.CharField(max_length=20, null=True) + other_optional = fields.CharField(max_length=20, null=True) + + class Meta: + table = "unique_names" + + +class TableWithSleepTrigger(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=50) + value = fields.CharField(max_length=100) + + class Meta: + table = "table_with_sleep_trigger" diff --git a/tests/integration/container/tortoise/models/test_models_relationships.py b/tests/integration/container/tortoise/models/test_models_relationships.py new file mode 100644 index 00000000..16e2bfe6 --- /dev/null +++ b/tests/integration/container/tortoise/models/test_models_relationships.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from tortoise import fields +from tortoise.models import Model + +if TYPE_CHECKING: + from tortoise.fields.relational import (ForeignKeyFieldInstance, + OneToOneFieldInstance) + +# One-to-One Relationship Models + + +class RelTestAccount(Model): + id = fields.IntField(primary_key=True) + username = fields.CharField(max_length=50, unique=True) + email = fields.CharField(max_length=100) + + class Meta: + table = "rel_test_accounts" + + +class RelTestAccountProfile(Model): + id = fields.IntField(primary_key=True) + account: "OneToOneFieldInstance" = fields.OneToOneField("models.RelTestAccount", related_name="profile", on_delete=fields.CASCADE) + bio = fields.TextField(null=True) + avatar_url = fields.CharField(max_length=200, null=True) + + class Meta: + table = "rel_test_account_profiles" + + +# One-to-Many Relationship Models +class RelTestPublisher(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + email = fields.CharField(max_length=100, unique=True) + + class Meta: + table = "rel_test_publishers" + + +class RelTestPublication(Model): + id = fields.IntField(primary_key=True) + title = fields.CharField(max_length=200) + isbn = fields.CharField(max_length=13, unique=True) + publisher: "ForeignKeyFieldInstance" = fields.ForeignKeyField("models.RelTestPublisher", related_name="publications", on_delete=fields.CASCADE) + published_date = fields.DateField(null=True) + + class Meta: + table = "rel_test_publications" + + +# Many-to-Many Relationship Models +class RelTestLearner(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + learner_id = fields.CharField(max_length=20, unique=True) + + class Meta: + table = "rel_test_learners" + + +class RelTestSubject(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + code = fields.CharField(max_length=10, unique=True) + credits = fields.IntField() + learners = fields.ManyToManyField("models.RelTestLearner", related_name="subjects") + + class Meta: + table = "rel_test_subjects" diff --git a/tests/integration/container/tortoise/router/__init__.py b/tests/integration/container/tortoise/router/__init__.py new file mode 100644 index 00000000..12de791f --- /dev/null +++ b/tests/integration/container/tortoise/router/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/tests/integration/container/tortoise/router/test_router.py b/tests/integration/container/tortoise/router/test_router.py new file mode 100644 index 00000000..e4693120 --- /dev/null +++ b/tests/integration/container/tortoise/router/test_router.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +class TestRouter: + def db_for_read(self, model): + return "default" + + def db_for_write(self, model): + return "default" diff --git a/tests/integration/container/tortoise/test_tortoise_basic.py b/tests/integration/container/tortoise/test_tortoise_basic.py new file mode 100644 index 00000000..d06f7de2 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_basic.py @@ -0,0 +1,261 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest +import pytest_asyncio +from tortoise.transactions import atomic, in_transaction + +from tests.integration.container.tortoise.models.test_models import ( + UniqueName, User) +from tests.integration.container.tortoise.test_tortoise_common import \ + setup_tortoise +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseBasic: + """ + Test class for Tortoise ORM integration with AWS Advanced Python Wrapper. + Contains tests related to basic test operations. + """ + + @pytest_asyncio.fixture + async def setup_tortoise_basic(self, conn_utils): + """Setup Tortoise with default plugins.""" + async for result in setup_tortoise(conn_utils): + yield result + + @pytest.mark.asyncio + async def test_basic_crud_operations(self, setup_tortoise_basic): + """Test basic CRUD operations with AWS wrapper.""" + # Create + user = await User.create(name="John Doe", email="john@example.com") + assert user.id is not None + assert user.name == "John Doe" + + # Read + found_user = await User.get(id=user.id) + assert found_user.name == "John Doe" + assert found_user.email == "john@example.com" + + # Update + found_user.name = "Jane Doe" + await found_user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Jane Doe" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_basic_crud_operations_config(self, setup_tortoise_basic): + """Test basic CRUD operations with AWS wrapper.""" + # Create + user = await User.create(name="John Doe", email="john@example.com") + assert user.id is not None + assert user.name == "John Doe" + + # Read + found_user = await User.get(id=user.id) + assert found_user.name == "John Doe" + assert found_user.email == "john@example.com" + + # Update + found_user.name = "Jane Doe" + await found_user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Jane Doe" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @pytest.mark.asyncio + async def test_transaction_support(self, setup_tortoise_basic): + """Test transaction handling with AWS wrapper.""" + async with in_transaction() as conn: + await User.create(name="User 1", email="user1@example.com", using_db=conn) + await User.create(name="User 2", email="user2@example.com", using_db=conn) + + # Verify users exist within transaction + users = await User.filter(name__in=["User 1", "User 2"]).using_db(conn) + assert len(users) == 2 + + @pytest.mark.asyncio + async def test_transaction_rollback(self, setup_tortoise_basic): + """Test transaction rollback with AWS wrapper.""" + try: + async with in_transaction() as conn: + await User.create(name="Test User", email="test@example.com", using_db=conn) + # Force rollback by raising exception + raise ValueError("Test rollback") + except ValueError: + pass + + # Verify user was not created due to rollback + users = await User.filter(name="Test User") + assert len(users) == 0 + + @pytest.mark.asyncio + async def test_bulk_operations(self, setup_tortoise_basic): + """Test bulk operations with AWS wrapper.""" + # Bulk create + users_data = [ + {"name": f"User {i}", "email": f"user{i}@example.com"} + for i in range(5) + ] + await User.bulk_create([User(**data) for data in users_data]) + + # Verify bulk creation + users = await User.filter(name__startswith="User") + assert len(users) == 5 + + # Bulk update + await User.filter(name__startswith="User").update(name="Updated User") + + updated_users = await User.filter(name="Updated User") + assert len(updated_users) == 5 + + @pytest.mark.asyncio + async def test_query_operations(self, setup_tortoise_basic): + """Test various query operations with AWS wrapper.""" + # Setup test data + await User.create(name="Alice", email="alice@example.com") + await User.create(name="Bob", email="bob@example.com") + await User.create(name="Charlie", email="charlie@example.com") + + # Test filtering + alice = await User.get(name="Alice") + assert alice.email == "alice@example.com" + + # Test count + count = await User.all().count() + assert count >= 3 + + # Test ordering + users = await User.all().order_by("name") + assert users[0].name == "Alice" + + # Test exists + exists = await User.filter(name="Alice").exists() + assert exists is True + + # Test values + emails = await User.all().values_list("email", flat=True) + assert "alice@example.com" in emails + + @pytest.mark.asyncio + async def test_bulk_create_with_ids(self, setup_tortoise_basic): + """Test bulk create operations with ID verification.""" + # Bulk create 1000 UniqueName objects with no name (null values) + await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) + + # Get all created records with id and name + all_ = await UniqueName.all().values("id", "name") + + # Get the starting ID + inc = all_[0]["id"] + + # Sort by ID for comparison + all_sorted = sorted(all_, key=lambda x: x["id"]) + + # Verify the IDs are sequential and names are None + expected = [{"id": val + inc, "name": None} for val in range(1000)] + + assert len(all_sorted) == 1000 + assert all_sorted == expected + + @pytest.mark.asyncio + async def test_concurrency_read(self, setup_tortoise_basic): + """Test concurrent read operations with AWS wrapper.""" + + await User.create(name="Test User", email="test@example.com") + user1 = await User.first() + + # Perform 100 concurrent reads + all_read = await asyncio.gather(*[User.first() for _ in range(100)]) + + # All reads should return the same user + assert all_read == [user1 for _ in range(100)] + + @pytest.mark.asyncio + async def test_concurrency_create(self, setup_tortoise_basic): + """Test concurrent create operations with AWS wrapper.""" + + # Perform 100 concurrent creates with unique emails + all_write = await asyncio.gather(*[ + User.create(name="Test", email=f"test{i}@example.com") + for i in range(100) + ]) + + # Read all created users + all_read = await User.all() + + # All created users should exist in the database + assert set(all_write) == set(all_read) + + @pytest.mark.asyncio + async def test_atomic_decorator(self, setup_tortoise_basic): + """Test atomic decorator for transaction handling with AWS wrapper.""" + + @atomic() + async def create_users_atomically(): + user1 = await User.create(name="Atomic User 1", email="atomic1@example.com") + user2 = await User.create(name="Atomic User 2", email="atomic2@example.com") + return user1, user2 + + # Execute atomic operation + user1, user2 = await create_users_atomically() + + # Verify both users were created + assert user1.id is not None + assert user2.id is not None + + # Verify users exist in database + found_users = await User.filter(name__startswith="Atomic User") + assert len(found_users) == 2 + + @pytest.mark.asyncio + async def test_atomic_decorator_rollback(self, setup_tortoise_basic): + """Test atomic decorator rollback on exception with AWS wrapper.""" + + @atomic() + async def create_users_with_error(): + await User.create(name="Atomic Rollback User", email="rollback@example.com") + # Force rollback by raising exception + raise ValueError("Intentional error for rollback test") + + # Execute atomic operation that should fail + with pytest.raises(ValueError): + await create_users_with_error() + + # Verify user was not created due to rollback + users = await User.filter(name="Atomic Rollback User") + assert len(users) == 0 diff --git a/tests/integration/container/tortoise/test_tortoise_common.py b/tests/integration/container/tortoise/test_tortoise_common.py new file mode 100644 index 00000000..b3528091 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_common.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from tortoise import Tortoise, connections + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise_orm # noqa: F401 +from tests.integration.container.tortoise.models.test_models import User +from tests.integration.container.utils.test_environment import TestEnvironment + + +async def clear_test_models(): + """Clear all test models by calling .all().delete() on each.""" + from tortoise.models import Model + + import tests.integration.container.tortoise.models.test_models as models_module + + for attr_name in dir(models_module): + attr = getattr(models_module, attr_name) + if isinstance(attr, type) and issubclass(attr, Model) and attr != Model: + await attr.all().delete() + + +async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwargs): + """Setup Tortoise with AWS MySQL backend and configurable plugins.""" + db_url = conn_utils.get_aws_tortoise_url( + TestEnvironment.get_current().get_engine(), + plugins=plugins, + **kwargs, + ) + config = { + "connections": { + "default": db_url + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + } + } + + await Tortoise.init(config=config) + + await Tortoise.generate_schemas() + await clear_test_models() + + yield + + await reset_tortoise() + + +async def run_basic_read_operations(name_prefix="Test", email_prefix="test"): + """Common test logic for basic read operations.""" + user = await User.create(name=f"{name_prefix} User", email=f"{email_prefix}@example.com") + + found_user = await User.get(id=user.id) + assert found_user.name == f"{name_prefix} User" + assert found_user.email == f"{email_prefix}@example.com" + + users = await User.filter(name=f"{name_prefix} User") + assert len(users) == 1 + assert users[0].id == user.id + + +async def run_basic_write_operations(name_prefix="Write", email_prefix="write"): + """Common test logic for basic write operations.""" + user = await User.create(name=f"{name_prefix} Test", email=f"{email_prefix}@example.com") + assert user.id is not None + + user.name = f"Updated {name_prefix}" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == f"Updated {name_prefix}" + + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + +async def reset_tortoise(): + await Tortoise.close_connections() + await Tortoise._reset_apps() + Tortoise._inited = False + Tortoise.apps = {} + connections._db_config = {} diff --git a/tests/integration/container/tortoise/test_tortoise_config.py b/tests/integration/container/tortoise/test_tortoise_config.py new file mode 100644 index 00000000..761206da --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_config.py @@ -0,0 +1,292 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise, connections + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise_orm # noqa: F401 +from tests.integration.container.tortoise.models.test_models import User +from tests.integration.container.tortoise.test_tortoise_common import \ + reset_tortoise +from tests.integration.container.utils.conditions import ( + disable_on_deployments, disable_on_engines, disable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseConfig: + """Test class for Tortoise ORM configuration scenarios.""" + + async def _clear_all_test_models(self): + """Clear all test models for a specific connection.""" + await User.all().delete() + + @pytest_asyncio.fixture + async def setup_tortoise_dict_config(self, conn_utils): + """Setup Tortoise with dictionary configuration instead of URL.""" + # Ensure clean state + host = conn_utils.writer_cluster_host if conn_utils.writer_cluster_host else conn_utils.writer_host + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise_orm.backends.mysql", + "credentials": { + "host": host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": conn_utils.dbname, + "plugins": "aurora_connection_tracker", + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + } + } + + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + await self._clear_all_test_models() + + yield + + await self._clear_all_test_models() + await reset_tortoise() + + @pytest_asyncio.fixture + async def setup_tortoise_multi_db(self, conn_utils): + """Setup Tortoise with two different databases using same backend.""" + # Create second database name + original_db = conn_utils.dbname + second_db = f"{original_db}_test2" + host = conn_utils.writer_cluster_host if conn_utils.writer_cluster_host else conn_utils.writer_host + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise_orm.backends.mysql", + "credentials": { + "host": host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": original_db, + "plugins": "aurora_connection_tracker", + } + }, + "second_db": { + "engine": "aws_advanced_python_wrapper.tortoise_orm.backends.mysql", + "credentials": { + "host": host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": second_db, + "plugins": "aurora_connection_tracker" + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + }, + "models2": { + "models": ["tests.integration.container.tortoise.models.test_models_copy"], + "default_connection": "second_db", + } + } + } + + await Tortoise.init(config=config) + + # Create second database + conn = connections.get("second_db") + try: + await conn.db_create() + except Exception: + pass + + await Tortoise.generate_schemas() + + await self._clear_all_test_models() + + yield second_db + await self._clear_all_test_models() + + # Drop second database + conn = connections.get("second_db") + await conn.db_delete() + await reset_tortoise() + + @pytest_asyncio.fixture + async def setup_tortoise_with_router(self, conn_utils): + """Setup Tortoise with router configuration.""" + host = conn_utils.writer_cluster_host if conn_utils.writer_cluster_host else conn_utils.writer_host + config = { + "connections": { + "default": { + "engine": "aws_advanced_python_wrapper.tortoise_orm.backends.mysql", + "credentials": { + "host": host, + "port": conn_utils.port, + "user": conn_utils.user, + "password": conn_utils.password, + "database": conn_utils.dbname, + "plugins": "aurora_connection_tracker" + } + } + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models"], + "default_connection": "default", + } + }, + "routers": ["tests.integration.container.tortoise.router.test_router.TestRouter"] + } + + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + await self._clear_all_test_models() + + yield + + await self._clear_all_test_models() + await reset_tortoise() + + @pytest.mark.asyncio + async def test_dict_config_read_operations(self, setup_tortoise_dict_config): + """Test basic read operations with dictionary configuration.""" + # Create test data + user = await User.create(name="Dict Config User", email="dict@example.com") + + # Read operations + found_user = await User.get(id=user.id) + assert found_user.name == "Dict Config User" + assert found_user.email == "dict@example.com" + + # Query operations + users = await User.filter(name="Dict Config User") + assert len(users) == 1 + assert users[0].id == user.id + + @pytest.mark.asyncio + async def test_dict_config_write_operations(self, setup_tortoise_dict_config): + """Test basic write operations with dictionary configuration.""" + # Create + user = await User.create(name="Dict Write Test", email="dictwrite@example.com") + assert user.id is not None + + # Update + user.name = "Updated Dict User" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Updated Dict User" + + # Delete + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) + + @disable_on_deployments([DatabaseEngineDeployment.DOCKER]) + @pytest.mark.asyncio + async def test_multi_db_operations(self, setup_tortoise_multi_db): + """Test operations with multiple databases using same backend.""" + # Get second connection + second_conn = connections.get("second_db") + + # Create users in different databases + user1 = await User.create(name="DB1 User", email="db1@example.com") + user2 = await User.create(name="DB2 User", email="db2@example.com", using_db=second_conn) + + # Verify users exist in their respective databases + db1_users = await User.all() + db2_users = await User.all().using_db(second_conn) + + assert len(db1_users) == 1 + assert len(db2_users) == 1 + assert db1_users[0].name == "DB1 User" + assert db2_users[0].name == "DB2 User" + + # Verify isolation - users don't exist in the other database + db1_user_in_db2 = await User.filter(name="DB1 User").using_db(second_conn) + db2_user_in_db1 = await User.filter(name="DB2 User") + + assert len(db1_user_in_db2) == 0 + assert len(db2_user_in_db1) == 0 + + # Test updates in different databases + user1.name = "Updated DB1 User" + await user1.save() + + user2.name = "Updated DB2 User" + await user2.save(using_db=second_conn) + + # Verify updates + updated_user1 = await User.get(id=user1.id) + updated_user2 = await User.get(id=user2.id, using_db=second_conn) + + assert updated_user1.name == "Updated DB1 User" + assert updated_user2.name == "Updated DB2 User" + + @pytest.mark.asyncio + async def test_router_read_operations(self, setup_tortoise_with_router): + """Test read operations with router configuration.""" + # Create test data + user = await User.create(name="Router User", email="router@example.com") + + # Read operations (should be routed by router) + found_user = await User.get(id=user.id) + assert found_user.name == "Router User" + assert found_user.email == "router@example.com" + + # Query operations + users = await User.filter(name="Router User") + assert len(users) == 1 + assert users[0].id == user.id + + @pytest.mark.asyncio + async def test_router_write_operations(self, setup_tortoise_with_router): + """Test write operations with router configuration.""" + # Create (should be routed by router) + user = await User.create(name="Router Write Test", email="routerwrite@example.com") + assert user.id is not None + + # Update (should be routed by router) + user.name = "Updated Router User" + await user.save() + + updated_user = await User.get(id=user.id) + assert updated_user.name == "Updated Router User" + + # Delete (should be routed by router) + await updated_user.delete() + + with pytest.raises(Exception): + await User.get(id=user.id) diff --git a/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py new file mode 100644 index 00000000..012e5572 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_custom_endpoint.py @@ -0,0 +1,155 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from time import perf_counter_ns, sleep +from uuid import uuid4 + +import pytest +import pytest_asyncio +from boto3 import client +from botocore.exceptions import ClientError + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_deployments([DatabaseEngineDeployment.AURORA]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseCustomEndpoint: + """Test class for Tortoise ORM with custom endpoint plugin.""" + endpoint_id = f"test-tortoise-endpoint-{uuid4()}" + endpoint_info: dict[str, str] = {} + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_custom_endpoint(self, rds_utils): + """Create a custom endpoint for testing.""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + instance_ids = [rds_utils.get_cluster_writer_instance_id()] + + try: + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + self._wait_until_endpoint_available(rds_client) + yield self.endpoint_info["Endpoint"] + finally: + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + self._wait_until_endpoint_deleted(rds_client) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pass # Ignore if endpoint doesn't exist + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + """Wait for the custom endpoint to become available.""" + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + sleep(3) + continue + + response_endpoint = response_endpoints[0] + TestTortoiseCustomEndpoint.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}") + + def _wait_until_endpoint_deleted(self, rds_client): + """Wait for the custom endpoint to be deleted.""" + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + + while perf_counter_ns() < end_ns: + try: + rds_client.describe_db_cluster_endpoints(DBClusterEndpointIdentifier=self.endpoint_id) + sleep(5) # Still exists, keep waiting + except ClientError as e: + if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault': + return # Successfully deleted + raise # Other error, re-raise + + @pytest_asyncio.fixture + async def setup_tortoise_custom_endpoint(self, conn_utils, create_custom_endpoint, request): + """Setup Tortoise with custom endpoint plugin.""" + plugins, user = request.param + user_value = getattr(conn_utils, user) if user != "default" else None + + kwargs = {} + if "fastest_response_strategy" in plugins: + kwargs["reader_host_selector_strategy"] = "fastest_response" + + async for result in setup_tortoise(conn_utils, plugins=plugins, host=create_custom_endpoint, user=user_value, **kwargs): + yield result + + @pytest.mark.parametrize("setup_tortoise_custom_endpoint", [ + ("custom_endpoint,aurora_connection_tracker", "default"), + ("failover,iam,aurora_connection_tracker,custom_endpoint,fastest_response_strategy", "iam_user") + ], indirect=True) + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_custom_endpoint): + """Test basic read operations with custom endpoint plugin.""" + await run_basic_read_operations("Custom Test", "custom") + + @pytest.mark.parametrize("setup_tortoise_custom_endpoint", [ + ("custom_endpoint,aurora_connection_tracker", "default"), + ("failover,iam,aurora_connection_tracker,custom_endpoint,fastest_response_strategy", "iam_user") + ], indirect=True) + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_custom_endpoint): + """Test basic write operations with custom endpoint plugin.""" + await run_basic_write_operations("Custom", "customwrite") diff --git a/tests/integration/container/tortoise/test_tortoise_failover.py b/tests/integration/container/tortoise/test_tortoise_failover.py new file mode 100644 index 00000000..8e31bba6 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_failover.py @@ -0,0 +1,253 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest +import pytest_asyncio +from tortoise import connections +from tortoise.transactions import in_transaction + +from aws_advanced_python_wrapper.errors import ( + FailoverSuccessError, TransactionResolutionUnknownError) +from tests.integration.container.tortoise.models.test_models import \ + TableWithSleepTrigger +from tests.integration.container.tortoise.test_tortoise_common import \ + setup_tortoise +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments, + enable_on_num_instances) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures +from tests.integration.container.utils.test_utils import get_sleep_trigger_sql + +# Global configuration for failover tests +FAILOVER_SLEEP_TIME = 10 # seconds to wait before triggering failover +CONCURRENT_THREAD_COUNT = 5 # number of concurrent threads/queries to spawn +SLEEP_TRIGGER_TIME = 120 # seconds for the database sleep trigger duration + + +# Shared helper functions for failover tests +async def run_single_insert_with_failover(create_record_func, aurora_utility, name_prefix="Test", value="test_value"): + """Helper to test single insert with failover.""" + insert_exception = None + + async def insert_task(): + nonlocal insert_exception + try: + await create_record_func(name_prefix, value) + except Exception as e: + insert_exception = e + + async def failover_task(): + await asyncio.sleep(FAILOVER_SLEEP_TIME) + await asyncio.to_thread(aurora_utility.failover_cluster_and_wait_until_writer_changed) + + await asyncio.gather(insert_task(), failover_task(), return_exceptions=True) + + assert insert_exception is not None + assert isinstance(insert_exception, FailoverSuccessError) + + +async def run_concurrent_queries_with_failover(aurora_utility, record_name="Concurrent Test", record_value="sleep_value"): + """Helper to test concurrent queries with failover.""" + connection = connections.get("default") + + async def run_select_query(query_id): + return await connection.execute_query(f"SELECT {query_id} as query_id") + + # Run concurrent select queries + initial_tasks = [run_select_query(i) for i in range(CONCURRENT_THREAD_COUNT)] + initial_results = await asyncio.gather(*initial_tasks) + assert len(initial_results) == CONCURRENT_THREAD_COUNT + + # Run insert query with failover + sleep_exception = None + + async def insert_query_task(): + nonlocal sleep_exception + try: + await TableWithSleepTrigger.create(name=record_name, value=record_value) + except Exception as e: + sleep_exception = e + + async def failover_task(): + await asyncio.sleep(FAILOVER_SLEEP_TIME) + + await asyncio.to_thread(aurora_utility.failover_cluster_and_wait_until_writer_changed) + + await asyncio.gather(insert_query_task(), failover_task(), return_exceptions=True) + + assert sleep_exception is not None + assert isinstance(sleep_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) + + # Run another set of concurrent select queries after failover + post_failover_tasks = [run_select_query(i + 100) for i in range(CONCURRENT_THREAD_COUNT)] + post_failover_results = await asyncio.gather(*post_failover_tasks) + assert len(post_failover_results) == CONCURRENT_THREAD_COUNT + + +async def run_multiple_concurrent_inserts_with_failover(aurora_utility, name_prefix="Concurrent Insert", value="insert_value"): + """Helper to test multiple concurrent inserts with failover.""" + insert_exceptions = [] + + async def insert_task(task_id): + try: + await TableWithSleepTrigger.create(name=f"{name_prefix} {task_id}", value=value) + except Exception as e: + insert_exceptions.append(e) + + async def failover_task(): + await asyncio.sleep(FAILOVER_SLEEP_TIME) + await asyncio.to_thread(aurora_utility.failover_cluster_and_wait_until_writer_changed) + + # Create concurrent insert tasks and 1 failover task + tasks = [insert_task(i) for i in range(CONCURRENT_THREAD_COUNT)] + tasks.append(failover_task()) + + await asyncio.gather(*tasks, return_exceptions=True) + + # Verify ALL tasks got FailoverSuccessError or TransactionResolutionUnknownError + assert len(insert_exceptions) == CONCURRENT_THREAD_COUNT, f"Expected {CONCURRENT_THREAD_COUNT} exceptions, got {len(insert_exceptions)}" + failover_errors = [e for e in insert_exceptions if isinstance(e, (FailoverSuccessError, TransactionResolutionUnknownError))] + assert len(failover_errors) == CONCURRENT_THREAD_COUNT, ( + f"Expected all {CONCURRENT_THREAD_COUNT} tasks to get failover errors, got {len(failover_errors)}" + ) + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_num_instances(min_instances=2) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseFailover: + """Test class for Tortoise ORM with Failover.""" + @pytest.fixture(scope='class') + def aurora_utility(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + async def _create_sleep_trigger_record(self, name_prefix="Plugin Test", value="test_value", using_db=None): + await TableWithSleepTrigger.create(name=f"{name_prefix}", value=value, using_db=using_db) + + @pytest_asyncio.fixture + async def sleep_trigger_setup(self): + """Setup and cleanup sleep trigger for testing.""" + connection = connections.get("default") + db_engine = TestEnvironment.get_current().get_engine() + trigger_sql = get_sleep_trigger_sql(db_engine, SLEEP_TRIGGER_TIME, "table_with_sleep_trigger") + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + await connection.execute_query(trigger_sql) + yield + await connection.execute_query("DROP TRIGGER IF EXISTS table_with_sleep_trigger_sleep_trigger") + + @pytest_asyncio.fixture + async def setup_tortoise_with_failover(self, conn_utils, request): + """Setup Tortoise with failover plugins.""" + plugins = request.param + kwargs = { + "topology_refresh_ms": 1000, + "connect_timeout": 15, + "monitoring-connect_timeout": 10, + "use_pure": True, + } + + # Add reader strategy if multiple plugins + if "fastest_response_strategy" in plugins: + kwargs["reader_host_selector_strategy"] = "fastest_response" + user = conn_utils.iam_user + kwargs.pop("use_pure") + else: + user = None + + async for result in setup_tortoise(conn_utils, plugins=plugins, user=user, **kwargs): + yield result + + @pytest.mark.parametrize("setup_tortoise_with_failover", [ + "failover", + "failover,aurora_connection_tracker,fastest_response_strategy,iam" + ], indirect=True) + @pytest.mark.asyncio + async def test_basic_operations_with_failover( + self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test failover when inserting to a single table""" + await run_single_insert_with_failover(self._create_sleep_trigger_record, aurora_utility) + + @pytest.mark.parametrize("setup_tortoise_with_failover", [ + "failover", + "failover,aurora_connection_tracker,fastest_response_strategy,iam" + ], indirect=True) + @pytest.mark.asyncio + async def test_transaction_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test transactions with failover during long-running operations.""" + transaction_exception = None + + async def transaction_task(): + nonlocal transaction_exception + try: + async with in_transaction() as conn: + await self._create_sleep_trigger_record("TX Plugin Test", "tx_test_value", conn) + except Exception as e: + transaction_exception = e + + async def failover_task(): + await asyncio.sleep(FAILOVER_SLEEP_TIME) + await asyncio.to_thread(aurora_utility.failover_cluster_and_wait_until_writer_changed) + + await asyncio.gather(transaction_task(), failover_task(), return_exceptions=True) + + assert transaction_exception is not None + assert isinstance(transaction_exception, (FailoverSuccessError, TransactionResolutionUnknownError)) + + # Verify no records were created due to transaction rollback + record_count = await TableWithSleepTrigger.all().count() + assert record_count == 0 + + # Verify autocommit is re-enabled after failover by inserting a record + autocommit_record = await TableWithSleepTrigger.create( + name="Autocommit Test", + value="autocommit_value" + ) + + # Verify the record exists (should be auto-committed) + found_record = await TableWithSleepTrigger.get(id=autocommit_record.id) + assert found_record.name == "Autocommit Test" + assert found_record.value == "autocommit_value" + + # Clean up the test record + await found_record.delete() + + @pytest.mark.parametrize("setup_tortoise_with_failover", [ + "failover", + "failover,aurora_connection_tracker,fastest_response_strategy,iam" + ], indirect=True) + @pytest.mark.asyncio + async def test_concurrent_queries_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test concurrent queries with failover during long-running operation.""" + await run_concurrent_queries_with_failover(aurora_utility) + + @pytest.mark.parametrize("setup_tortoise_with_failover", [ + "failover", + "failover,aurora_connection_tracker,fastest_response_strategy,iam" + ], indirect=True) + @pytest.mark.asyncio + async def test_multiple_concurrent_inserts_with_failover(self, setup_tortoise_with_failover, sleep_trigger_setup, aurora_utility): + """Test multiple concurrent insert operations with failover during long-running operations.""" + await run_multiple_concurrent_inserts_with_failover(aurora_utility) diff --git a/tests/integration/container/tortoise/test_tortoise_iam_authentication.py b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py new file mode 100644 index 00000000..b4f9e32d --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_iam_authentication.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import (disable_on_engines, + disable_on_features, + enable_on_features) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_features([TestEnvironmentFeatures.IAM]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseIamAuthentication: + """Test class for Tortoise ORM with IAM authentication.""" + + @pytest_asyncio.fixture + async def setup_tortoise_iam(self, conn_utils): + """Setup Tortoise with IAM authentication.""" + async for result in setup_tortoise(conn_utils, plugins="iam", user=conn_utils.iam_user): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_iam): + """Test basic read operations with IAM authentication.""" + await run_basic_read_operations() + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_iam): + """Test basic write operations with IAM authentication.""" + await run_basic_write_operations() diff --git a/tests/integration/container/tortoise/test_tortoise_relations.py b/tests/integration/container/tortoise/test_tortoise_relations.py new file mode 100644 index 00000000..28a88e64 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_relations.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio +from tortoise import Tortoise +from tortoise.exceptions import IntegrityError + +# Import to register the aws-mysql backend +import aws_advanced_python_wrapper.tortoise_orm # noqa: F401 +from tests.integration.container.tortoise.models.test_models_relationships import ( + RelTestAccount, RelTestAccountProfile, RelTestLearner, RelTestPublication, + RelTestPublisher, RelTestSubject) +from tests.integration.container.tortoise.test_tortoise_common import \ + reset_tortoise +from tests.integration.container.utils.conditions import disable_on_engines +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.test_environment import TestEnvironment + + +async def clear_relationship_models(): + """Clear all relationship test models.""" + await RelTestSubject.all().delete() + await RelTestLearner.all().delete() + await RelTestPublication.all().delete() + await RelTestPublisher.all().delete() + await RelTestAccountProfile.all().delete() + await RelTestAccount.all().delete() + + +@disable_on_engines([DatabaseEngine.PG]) +class TestTortoiseRelationships: + + @pytest_asyncio.fixture(scope="function") + async def setup_tortoise_relationships(self, conn_utils): + """Setup Tortoise with relationship models.""" + db_url = conn_utils.get_aws_tortoise_url( + TestEnvironment.get_current().get_engine(), + plugins="aurora_connection_tracker", + ) + + config = { + "connections": { + "default": db_url + }, + "apps": { + "models": { + "models": ["tests.integration.container.tortoise.models.test_models_relationships"], + "default_connection": "default", + } + } + } + + await Tortoise.init(config=config) + await Tortoise.generate_schemas() + + # Clear any existing data + await clear_relationship_models() + + yield + + # Cleanup + await clear_relationship_models() + await reset_tortoise() + + @pytest.mark.asyncio + async def test_one_to_one_relationship_creation(self, setup_tortoise_relationships): + """Test creating one-to-one relationships""" + account = await RelTestAccount.create(username="testuser", email="test@example.com") + profile = await RelTestAccountProfile.create(account=account, bio="Test bio", avatar_url="http://example.com/avatar.jpg") + + # Test forward relationship + assert profile.account_id == account.id + + # Test reverse relationship + account_with_profile = await RelTestAccount.get(id=account.id).prefetch_related("profile") + assert account_with_profile.profile.bio == "Test bio" + + @pytest.mark.asyncio + async def test_one_to_one_cascade_delete(self, setup_tortoise_relationships): + """Test cascade delete in one-to-one relationship""" + account = await RelTestAccount.create(username="deleteuser", email="delete@example.com") + await RelTestAccountProfile.create(account=account, bio="Will be deleted") + + # Delete account should cascade to profile + await account.delete() + + # Profile should be deleted + profile_count = await RelTestAccountProfile.filter(account_id=account.id).count() + assert profile_count == 0 + + @pytest.mark.asyncio + async def test_one_to_one_unique_constraint(self, setup_tortoise_relationships): + """Test unique constraint in one-to-one relationship""" + account = await RelTestAccount.create(username="uniqueuser", email="unique@example.com") + await RelTestAccountProfile.create(account=account, bio="First profile") + + # Creating another profile for same account should fail + with pytest.raises(IntegrityError): + await RelTestAccountProfile.create(account=account, bio="Second profile") + + @pytest.mark.asyncio + async def test_one_to_many_relationship_creation(self, setup_tortoise_relationships): + """Test creating one-to-many relationships""" + publisher = await RelTestPublisher.create(name="Test Publisher", email="publisher@example.com") + pub1 = await RelTestPublication.create(title="Publication 1", isbn="1234567890123", publisher=publisher) + pub2 = await RelTestPublication.create(title="Publication 2", isbn="1234567890124", publisher=publisher) + + # Test forward relationship + assert pub1.publisher_id == publisher.id + assert pub2.publisher_id == publisher.id + + # Test reverse relationship + publisher_with_pubs = await RelTestPublisher.get(id=publisher.id).prefetch_related("publications") + assert len(publisher_with_pubs.publications) == 2 + assert {pub.title for pub in publisher_with_pubs.publications} == {"Publication 1", "Publication 2"} + + @pytest.mark.asyncio + async def test_one_to_many_cascade_delete(self, setup_tortoise_relationships): + """Test cascade delete in one-to-many relationship""" + publisher = await RelTestPublisher.create(name="Delete Publisher", email="deletepub@example.com") + await RelTestPublication.create(title="Delete Pub 1", isbn="9999999999999", publisher=publisher) + await RelTestPublication.create(title="Delete Pub 2", isbn="9999999999998", publisher=publisher) + + # Delete publisher should cascade to publications + await publisher.delete() + + # Publications should be deleted + pub_count = await RelTestPublication.filter(publisher_id=publisher.id).count() + assert pub_count == 0 + + @pytest.mark.asyncio + async def test_foreign_key_constraint(self, setup_tortoise_relationships): + """Test foreign key constraint enforcement""" + # Creating publication with non-existent publisher should fail + with pytest.raises(IntegrityError): + await RelTestPublication.create(title="Orphan Publication", isbn="0000000000000", publisher_id=99999) + + @pytest.mark.asyncio + async def test_many_to_many_relationship_creation(self, setup_tortoise_relationships): + """Test creating many-to-many relationships""" + learner1 = await RelTestLearner.create(name="Learner 1", learner_id="L001") + learner2 = await RelTestLearner.create(name="Learner 2", learner_id="L002") + subject1 = await RelTestSubject.create(name="Math 101", code="MATH101", credits=3) + subject2 = await RelTestSubject.create(name="Physics 101", code="PHYS101", credits=4) + + # Add learners to subjects + await subject1.learners.add(learner1, learner2) + await subject2.learners.add(learner1) + + # Test forward relationship + subject1_with_learners = await RelTestSubject.get(id=subject1.id).prefetch_related("learners") + assert len(subject1_with_learners.learners) == 2 + assert {learner.name for learner in subject1_with_learners.learners} == {"Learner 1", "Learner 2"} + + # Test reverse relationship + learner1_with_subjects = await RelTestLearner.get(id=learner1.id).prefetch_related("subjects") + assert len(learner1_with_subjects.subjects) == 2 + assert {subject.name for subject in learner1_with_subjects.subjects} == {"Math 101", "Physics 101"} + + @pytest.mark.asyncio + async def test_many_to_many_remove_relationship(self, setup_tortoise_relationships): + """Test removing many-to-many relationships""" + learner = await RelTestLearner.create(name="Remove Learner", learner_id="L003") + subject = await RelTestSubject.create(name="Remove Subject", code="REM101", credits=2) + + # Add relationship + await subject.learners.add(learner) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 1 + + # Remove relationship + await subject.learners.remove(learner) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 0 + + @pytest.mark.asyncio + async def test_many_to_many_clear_relationships(self, setup_tortoise_relationships): + """Test clearing all many-to-many relationships""" + learner1 = await RelTestLearner.create(name="Clear Learner 1", learner_id="L004") + learner2 = await RelTestLearner.create(name="Clear Learner 2", learner_id="L005") + subject = await RelTestSubject.create(name="Clear Subject", code="CLR101", credits=1) + + # Add multiple relationships + await subject.learners.add(learner1, learner2) + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 2 + + # Clear all relationships + await subject.learners.clear() + subject_with_learners = await RelTestSubject.get(id=subject.id).prefetch_related("learners") + assert len(subject_with_learners.learners) == 0 + + @pytest.mark.asyncio + async def test_unique_constraints(self, setup_tortoise_relationships): + """Test unique constraints on model fields""" + # Test account unique username + await RelTestAccount.create(username="unique1", email="unique1@example.com") + with pytest.raises(IntegrityError): + await RelTestAccount.create(username="unique1", email="different@example.com") + + # Test publication unique ISBN + publisher = await RelTestPublisher.create(name="ISBN Publisher", email="isbn@example.com") + await RelTestPublication.create(title="ISBN Pub 1", isbn="1111111111111", publisher=publisher) + with pytest.raises(IntegrityError): + await RelTestPublication.create(title="ISBN Pub 2", isbn="1111111111111", publisher=publisher) + + # Test subject unique code + await RelTestSubject.create(name="Unique Subject 1", code="UNQ101", credits=3) + with pytest.raises(IntegrityError): + await RelTestSubject.create(name="Unique Subject 2", code="UNQ101", credits=4) diff --git a/tests/integration/container/tortoise/test_tortoise_secrets_manager.py b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py new file mode 100644 index 00000000..c79d39f8 --- /dev/null +++ b/tests/integration/container/tortoise/test_tortoise_secrets_manager.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from uuid import uuid4 + +import boto3 +import pytest +import pytest_asyncio + +from tests.integration.container.tortoise.test_tortoise_common import ( + run_basic_read_operations, run_basic_write_operations, setup_tortoise) +from tests.integration.container.utils.conditions import ( + disable_on_engines, disable_on_features, enable_on_deployments) +from tests.integration.container.utils.database_engine import DatabaseEngine +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@disable_on_engines([DatabaseEngine.PG]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestTortoiseSecretsManager: + """Test class for Tortoise ORM with AWS Secrets Manager authentication.""" + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + client = boto3.client('secretsmanager', region_name=region) + + secret_name = f"test-tortoise-secret-{uuid4()}" + secret_value = { + "username": conn_utils.user, + "password": conn_utils.password + } + + try: + response = client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_id = response['ARN'] + yield secret_id + finally: + try: + client.delete_secret( + SecretId=secret_name, + ForceDeleteWithoutRecovery=True + ) + except Exception: + pass + + @pytest_asyncio.fixture + async def setup_tortoise_secrets_manager(self, conn_utils, create_secret): + """Setup Tortoise with Secrets Manager authentication.""" + async for result in setup_tortoise(conn_utils, + plugins="aws_secrets_manager", + secrets_manager_secret_id=create_secret): + yield result + + @pytest.mark.asyncio + async def test_basic_read_operations(self, setup_tortoise_secrets_manager): + """Test basic read operations with Secrets Manager authentication.""" + await run_basic_read_operations("Secrets", "secrets") + + @pytest.mark.asyncio + async def test_basic_write_operations(self, setup_tortoise_secrets_manager): + """Test basic write operations with Secrets Manager authentication.""" + await run_basic_write_operations("Secrets", "secretswrite") diff --git a/tests/integration/container/utils/conditions.py b/tests/integration/container/utils/conditions.py index c6c96041..604e5f7b 100644 --- a/tests/integration/container/utils/conditions.py +++ b/tests/integration/container/utils/conditions.py @@ -87,3 +87,11 @@ def disable_on_features(disable_on_test_features: List[TestEnvironmentFeatures]) disable_test, reason="The current test environment contains test features for which this test is disabled" ) + + +def disable_on_deployments(requested_deployments: List[DatabaseEngineDeployment]): + current_deployment = TestEnvironment.get_current().get_deployment() + return pytest.mark.skipif( + current_deployment in requested_deployments, + reason=f"This test is not supported for {current_deployment.value} deployments" + ) diff --git a/tests/integration/container/utils/connection_utils.py b/tests/integration/container/utils/connection_utils.py index 52d0add5..e85a7a0a 100644 --- a/tests/integration/container/utils/connection_utils.py +++ b/tests/integration/container/utils/connection_utils.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Optional +from .database_engine import DatabaseEngine from .driver_helper import DriverHelper from .test_environment import TestEnvironment @@ -97,3 +98,31 @@ def get_proxy_connect_params( password = self.password if password is None else password dbname = self.dbname if dbname is None else dbname return DriverHelper.get_connect_params(host, port, user, password, dbname) + + def get_aws_tortoise_url( + self, + db_engine: DatabaseEngine, + host: Optional[str] = None, + port: Optional[int] = None, + user: Optional[str] = None, + password: Optional[str] = None, + dbname: Optional[str] = None, + **kwargs) -> str: + """Build AWS MySQL connection URL for Tortoise ORM with query parameters.""" + env_host = self.writer_cluster_host if self.writer_cluster_host else self.writer_host + host = env_host if host is None else host + port = self.port if port is None else port + user = self.user if user is None else user + password = self.password if password is None else password + dbname = self.dbname if dbname is None else dbname + + # Build base URL + protocol = "aws-pg" if db_engine == DatabaseEngine.PG else "aws-mysql" + url = f"{protocol}://{user}:{password}@{host}:{port}/{dbname}" + + # Add all kwargs as query parameters + if kwargs: + params = [f"{key}={value}" for key, value in kwargs.items()] + url += "?" + "&".join(params) + + return url diff --git a/tests/integration/container/utils/test_telemetry_info.py b/tests/integration/container/utils/test_telemetry_info.py index 24267dc8..e37941ef 100644 --- a/tests/integration/container/utils/test_telemetry_info.py +++ b/tests/integration/container/utils/test_telemetry_info.py @@ -11,18 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + import typing from typing import Any, Dict diff --git a/tests/integration/container/utils/test_utils.py b/tests/integration/container/utils/test_utils.py new file mode 100644 index 00000000..a8942c0b --- /dev/null +++ b/tests/integration/container/utils/test_utils.py @@ -0,0 +1,52 @@ +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .database_engine import DatabaseEngine + + +def get_sleep_sql(db_engine: DatabaseEngine, seconds: int): + if db_engine == DatabaseEngine.MYSQL: + return f"SELECT SLEEP({seconds});" + elif db_engine == DatabaseEngine.PG: + return f"SELECT PG_SLEEP({seconds});" + else: + raise ValueError("Unknown database engine: " + str(db_engine)) + + +def get_sleep_trigger_sql(db_engine: DatabaseEngine, duration: int, table_name: str) -> str: + """Generate SQL to create a sleep trigger for INSERT operations.""" + if db_engine == DatabaseEngine.MYSQL: + return f""" + CREATE TRIGGER {table_name}_sleep_trigger + BEFORE INSERT ON {table_name} + FOR EACH ROW + BEGIN + DO SLEEP({duration}); + END + """ + elif db_engine == DatabaseEngine.PG: + return f""" + CREATE OR REPLACE FUNCTION {table_name}_sleep_function() + RETURNS TRIGGER AS $$ + BEGIN + PERFORM pg_sleep({duration}); + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + + CREATE TRIGGER {table_name}_sleep_trigger + BEFORE INSERT ON {table_name} + FOR EACH ROW + EXECUTE FUNCTION {table_name}_sleep_function(); + """ + else: + raise ValueError(f"Unknown database engine: {db_engine}") diff --git a/tests/unit/test_async_connection_pool.py b/tests/unit/test_async_connection_pool.py new file mode 100644 index 00000000..eab12745 --- /dev/null +++ b/tests/unit/test_async_connection_pool.py @@ -0,0 +1,440 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time + +import pytest + +from aws_advanced_python_wrapper.errors import (ConnectionReleasedError, + PoolClosingError, + PoolExhaustedError, + PoolNotInitializedError) +from aws_advanced_python_wrapper.tortoise_orm.async_support.async_connection_pool import ( + AsyncConnectionPool, ConnectionState, PoolConfig) + + +class MockConnection: + """Mock connection for testing""" + + def __init__(self, connection_id: int = 1): + self.connection_id = connection_id + self.closed = False + self.is_closed = False + + async def close(self): + self.closed = True + + def some_method(self): + return "mock_result" + + +@pytest.fixture +def mock_creator(): + """Mock connection creator""" + counter = 0 + + async def creator(): + nonlocal counter + counter += 1 + return MockConnection(counter) + return creator + + +@pytest.fixture +def mock_health_check(): + """Mock health check function""" + async def health_check(conn): + if hasattr(conn, 'healthy') and not conn.healthy: + raise Exception("Connection unhealthy") + return health_check + + +@pytest.fixture +def pool_config(): + """Basic pool configuration""" + return PoolConfig( + min_size=1, + max_size=3, + acquire_conn_timeout=1.0, + max_conn_lifetime=10.0, + max_conn_idle_time=5.0, + health_check_interval=0.1 + ) + + +class TestAsyncConnectionPool: + """Test cases for AsyncConnectionPool""" + + @pytest.mark.asyncio + async def test_pool_initialization(self, mock_creator, pool_config): + """Test pool initialization creates minimum connections""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + + await pool.initialize() + + stats = pool.get_stats() + assert stats["initialized"] is True + assert stats["total_size"] >= 1 + assert stats["available_in_queue"] >= 0 + + await pool.close() + + @pytest.mark.asyncio + async def test_acquire_and_release(self, mock_creator, pool_config): + """Test basic acquire and release operations""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + # Acquire connection + conn = await pool.acquire() + assert conn.connection_id == 1 + assert conn.state == ConnectionState.IN_USE + + stats = pool.get_stats() + assert stats["in_use"] >= 1 + assert stats["available_in_queue"] == 0 + + # Release connection + await conn.release() + + stats = pool.get_stats() + assert stats["in_use"] == 0 + assert stats["available_in_queue"] >= 1 + + await pool.close() + + @pytest.mark.asyncio + async def test_context_manager(self, mock_creator, pool_config): + """Test context manager automatically releases connections""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + async with pool.connection() as conn: + assert conn.state == ConnectionState.IN_USE + + stats = pool.get_stats() + assert stats["in_use"] == 0 + assert stats["available_in_queue"] >= 1 + + await pool.close() + + @pytest.mark.asyncio + async def test_pool_expansion(self, mock_creator, pool_config): + """Test pool creates new connections when needed""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + # Acquire all connections + conn1 = await pool.acquire() + conn2 = await pool.acquire() + conn3 = await pool.acquire() + + stats = pool.get_stats() + assert stats["total_size"] == 3 + assert stats["in_use"] == 3 + + await conn1.release() + await conn2.release() + await conn3.release() + await pool.close() + + @pytest.mark.asyncio + async def test_pool_exhaustion(self, mock_creator): + """Test pool exhaustion raises PoolExhaustedError""" + config = PoolConfig(min_size=1, max_size=1, overflow=0, acquire_conn_timeout=0.1) + pool = AsyncConnectionPool(mock_creator, config=config) + await pool.initialize() + + # Acquire the only connection + conn1 = await pool.acquire() + + # Try to acquire another - should timeout + with pytest.raises(PoolExhaustedError): + await pool.acquire() + + await conn1.release() + await pool.close() + + @pytest.mark.asyncio + async def test_connection_validation_failure(self, mock_creator, mock_health_check, pool_config): + """Test connection validation and recreation""" + pool = AsyncConnectionPool(mock_creator, health_check=mock_health_check, config=pool_config) + await pool.initialize() + + # Get connection and mark it unhealthy + conn = await pool.acquire() + conn.connection.healthy = False + await conn.release() + + # Next acquire should recreate connection due to failed health check + new_conn = await pool.acquire() + assert new_conn.connection_id == 2 # New connection created + + await new_conn.release() + await pool.close() + + @pytest.mark.asyncio + async def test_stale_connection_detection(self, mock_creator, pool_config): + """Test stale connection detection""" + config = PoolConfig(min_size=1, max_size=3, max_conn_lifetime=0.1, max_conn_idle_time=0.1) + pool = AsyncConnectionPool(mock_creator, config=config) + await pool.initialize() + + conn = await pool.acquire() + + # Make connection stale + conn.created_at = time.monotonic() - 1.0 + + assert conn.is_stale(0.1, 0.1) is True + + await conn.release() + await pool.close() + + @pytest.mark.asyncio + async def test_connection_proxy_methods(self, mock_creator, pool_config): + """Test connection proxies methods to underlying connection""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + conn = await pool.acquire() + + # Test method proxying + result = conn.some_method() + assert result == "mock_result" + + await conn.release() + await pool.close() + + @pytest.mark.asyncio + async def test_released_connection_access(self, mock_creator, pool_config): + """Test accessing released connection raises error""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + conn = await pool.acquire() + await conn.release() + + # Accessing released connection should raise error + with pytest.raises(ConnectionReleasedError): + _ = conn.some_method() + + await pool.close() + + @pytest.mark.asyncio + async def test_pool_not_initialized_error(self, mock_creator, pool_config): + """Test acquiring from uninitialized pool raises error""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + + with pytest.raises(PoolNotInitializedError): + await pool.acquire() + + @pytest.mark.asyncio + async def test_pool_closing_error(self, mock_creator, pool_config): + """Test acquiring from closing pool raises error""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + # Start closing + pool._closing = True + + with pytest.raises(PoolClosingError): + await pool.acquire() + + await pool.close() + + @pytest.mark.asyncio + async def test_pool_close_cleanup(self, mock_creator, pool_config): + """Test pool close cleans up all connections""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + # Acquire some connections + conn1 = await pool.acquire() + conn2 = await pool.acquire() + + await pool.close() + + # Connections should be closed + assert conn1.connection.closed is True + assert conn2.connection.closed is True + + stats = pool.get_stats() + assert stats["closing"] is True + + @pytest.mark.asyncio + async def test_maintenance_loop_creates_connections(self, mock_creator): + """Test maintenance loop creates connections when below minimum""" + config = PoolConfig(min_size=2, max_size=5, health_check_interval=0.05) + pool = AsyncConnectionPool(mock_creator, config=config) + await pool.initialize() + + # Remove a connection to go below minimum + conn = await pool.acquire() + await pool._close_connection(conn) + + # Wait for maintenance loop + await asyncio.sleep(0.1) + + stats = pool.get_stats() + assert stats["total_size"] >= 2 # Should restore minimum + + await pool.close() + + @pytest.mark.asyncio + async def test_overflow_connections(self, mock_creator): + """Test overflow connections beyond max_size""" + config = PoolConfig(min_size=1, max_size=2, overflow=1, acquire_conn_timeout=0.1) + pool = AsyncConnectionPool(mock_creator, config=config) + await pool.initialize() + + # Acquire max + overflow connections + conn1 = await pool.acquire() + conn2 = await pool.acquire() + conn3 = await pool.acquire() # Overflow connection + + stats = pool.get_stats() + assert stats["total_size"] == 3 + + await conn1.release() + await conn2.release() + await conn3.release() + + # After releasing all connections, should only keep max_size (2) connections + stats = pool.get_stats() + assert stats["total_size"] == 2 + + await pool.close() + + @pytest.mark.asyncio + async def test_connection_creator_failure(self, pool_config): + """Test handling of connection creation failures""" + async def failing_creator(): + raise Exception("Connection creation failed") + + pool = AsyncConnectionPool(failing_creator, config=pool_config) + + with pytest.raises(Exception): + await pool.initialize() + + @pytest.mark.asyncio + async def test_health_check_timeout(self, mock_creator): + """Test health check with timeout""" + async def slow_health_check(conn): + await asyncio.sleep(1) + + config = PoolConfig(min_size=1, max_size=3, health_check_timeout=0.1) + pool = AsyncConnectionPool(mock_creator, health_check=slow_health_check, config=config) + await pool.initialize() + + conn = await pool.acquire() + await conn.release() + + # Next acquire should recreate due to health check timeout + new_conn = await pool.acquire() + # Should have created at least one new connection due to health check failure + assert new_conn.connection_id > 1 + + await new_conn.release() + await pool.close() + + @pytest.mark.asyncio + async def test_concurrent_acquire_release(self, mock_creator, pool_config): + """Test concurrent acquire and release operations""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + async def acquire_release_task(): + conn = await pool.acquire() + await asyncio.sleep(0.01) # Simulate work + await conn.release() + + # Run multiple concurrent tasks + tasks = [acquire_release_task() for _ in range(10)] + await asyncio.gather(*tasks) + + stats = pool.get_stats() + assert stats["in_use"] == 0 # All connections should be released + + await pool.close() + + @pytest.mark.asyncio + async def test_double_release_ignored(self, mock_creator, pool_config): + """Test double release is safely ignored""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + conn = await pool.acquire() + await conn.release() + await conn.release() # Should not raise error + + await pool.close() + + def test_get_stats_format(self, mock_creator, pool_config): + """Test get_stats returns correct format""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + + stats = pool.get_stats() + + expected_keys = { + "total_size", "idle", "in_use", "available_in_queue", + "max_size", "overflow", "min_size", "initialized", "closing" + } + assert set(stats.keys()) == expected_keys + assert stats["initialized"] is False + assert stats["closing"] is False + + @pytest.mark.asyncio + async def test_maintenance_loop_removes_stale_connections(self, mock_creator): + """Test maintenance loop removes stale idle connections""" + config = PoolConfig(min_size=2, max_size=5, max_conn_idle_time=0.1, health_check_interval=0.05) + pool = AsyncConnectionPool(mock_creator, config=config) + await pool.initialize() + + # Acquire and release a connection to make it idle + conn = await pool.acquire() + await conn.release() + + # Make the connection stale by backdating its last_used time + async with pool._lock: + for pooled_conn in pool._all_connections.values(): + if pooled_conn.state == ConnectionState.IDLE: + pooled_conn.last_used = time.monotonic() - 1.0 # Make it stale + + # Wait for maintenance loop to run + await asyncio.sleep(0.15) + + # Should have removed stale connection and created new one to maintain min_size + stats = pool.get_stats() + assert stats["total_size"] == config.min_size + + await pool.close() + + @pytest.mark.asyncio + async def test_default_closer(self, mock_creator, pool_config): + """Test default closer handles both sync and async close methods""" + pool = AsyncConnectionPool(mock_creator, config=pool_config) + await pool.initialize() + + conn = await pool.acquire() + connection_id = conn.connection_id + + # Close the connection (should use default closer) + await pool._close_connection(conn) + + # Connection should be closed and removed from tracking + assert conn.connection.closed is True + assert connection_id not in pool._all_connections + + await pool.close() diff --git a/tests/unit/test_async_wrapper.py b/tests/unit/test_async_wrapper.py new file mode 100644 index 00000000..17d72c96 --- /dev/null +++ b/tests/unit/test_async_wrapper.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import pytest + +from aws_advanced_python_wrapper.tortoise_orm.async_support.async_wrapper import ( + AwsConnectionAsyncWrapper, AwsCursorAsyncWrapper, AwsWrapperAsyncConnector) + + +class TestAwsCursorAsyncWrapper: + def test_init(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + assert wrapper._cursor == mock_cursor + + @pytest.mark.asyncio + async def test_execute(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "result" + + result = await wrapper.execute("SELECT 1", ["param"]) + + mock_to_thread.assert_called_once_with(mock_cursor.execute, "SELECT 1", ["param"]) + assert result == "result" + + @pytest.mark.asyncio + async def test_executemany(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "result" + + result = await wrapper.executemany("INSERT", [["param1"], ["param2"]]) + + mock_to_thread.assert_called_once_with(mock_cursor.executemany, "INSERT", [["param1"], ["param2"]]) + assert result == "result" + + @pytest.mark.asyncio + async def test_fetchall(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = [("row1",), ("row2",)] + + result = await wrapper.fetchall() + + mock_to_thread.assert_called_once_with(mock_cursor.fetchall) + assert result == [("row1",), ("row2",)] + + @pytest.mark.asyncio + async def test_fetchone(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = ("row1",) + + result = await wrapper.fetchone() + + mock_to_thread.assert_called_once_with(mock_cursor.fetchone) + assert result == ("row1",) + + @pytest.mark.asyncio + async def test_close(self): + mock_cursor = MagicMock() + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await wrapper.close() + mock_to_thread.assert_called_once_with(mock_cursor.close) + + def test_getattr(self): + mock_cursor = MagicMock() + mock_cursor.rowcount = 5 + wrapper = AwsCursorAsyncWrapper(mock_cursor) + + assert wrapper.rowcount == 5 + + +class TestAwsConnectionAsyncWrapper: + def test_init(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + assert wrapper._wrapped_connection == mock_connection + + @pytest.mark.asyncio + async def test_cursor_context_manager(self): + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.side_effect = [mock_cursor, None] + + async with wrapper.cursor() as cursor: + assert isinstance(cursor, AwsCursorAsyncWrapper) + assert cursor._cursor == mock_cursor + + assert mock_to_thread.call_count == 2 + + @pytest.mark.asyncio + async def test_rollback(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "rollback_result" + + result = await wrapper.rollback() + + mock_to_thread.assert_called_once_with(mock_connection.rollback) + assert result == "rollback_result" + + @pytest.mark.asyncio + async def test_commit(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = "commit_result" + + result = await wrapper.commit() + + mock_to_thread.assert_called_once_with(mock_connection.commit) + assert result == "commit_result" + + @pytest.mark.asyncio + async def test_set_autocommit(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await wrapper.set_autocommit(True) + + mock_to_thread.assert_called_once_with(setattr, mock_connection, 'autocommit', True) + + @pytest.mark.asyncio + async def test_close(self): + mock_connection = MagicMock() + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await wrapper.close() + + mock_to_thread.assert_called_once_with(mock_connection.close) + + def test_getattr(self): + mock_connection = MagicMock() + mock_connection.some_attr = "test_value" + wrapper = AwsConnectionAsyncWrapper(mock_connection) + + assert wrapper.some_attr == "test_value" + + +class TestAwsWrapperAsyncConnector: + @pytest.mark.asyncio + async def test_connect_with_aws_wrapper(self): + mock_connect_func = MagicMock() + mock_connection = MagicMock() + kwargs = {"host": "localhost", "user": "test"} + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = mock_connection + + result = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(mock_connect_func, **kwargs) + + mock_to_thread.assert_called_once() + assert isinstance(result, AwsConnectionAsyncWrapper) + assert result._wrapped_connection == mock_connection + + @pytest.mark.asyncio + async def test_close_aws_wrapper(self): + mock_connection = MagicMock() + + with patch('asyncio.to_thread') as mock_to_thread: + mock_to_thread.return_value = None + + await AwsWrapperAsyncConnector.close_aws_wrapper(mock_connection) + + mock_to_thread.assert_called_once_with(mock_connection.close) diff --git a/tests/unit/test_fastest_response_strategy_plugin.py b/tests/unit/test_fastest_response_strategy_plugin.py index 733e1dfa..23d59d49 100644 --- a/tests/unit/test_fastest_response_strategy_plugin.py +++ b/tests/unit/test_fastest_response_strategy_plugin.py @@ -25,12 +25,12 @@ @pytest.fixture def writer_host(): - return HostInfo("instance-0", 5432, HostRole.WRITER, HostAvailability.AVAILABLE) + return HostInfo("instance-0", 5432, HostRole.WRITER, HostAvailability.AVAILABLE) @pytest.fixture def reader_host1() -> HostInfo: - return HostInfo("instance-1", 5432, HostRole.READER, HostAvailability.AVAILABLE) + return HostInfo("instance-1", 5432, HostRole.READER, HostAvailability.AVAILABLE) @pytest.fixture diff --git a/tests/unit/test_monitor_service.py b/tests/unit/test_monitor_service.py index ed8e743b..ebea6736 100644 --- a/tests/unit/test_monitor_service.py +++ b/tests/unit/test_monitor_service.py @@ -138,7 +138,7 @@ def test_start_monitoring__errors(monitor_service_mocked_container, mock_conn, m def test_stop_monitoring(monitor_service_with_container, mock_monitor, mock_conn): aliases = frozenset({"instance-1"}) context = monitor_service_with_container.start_monitoring( - mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) + mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) monitor_service_with_container.stop_monitoring(context) mock_monitor.stop_monitoring.assert_called_once_with(context) @@ -146,7 +146,7 @@ def test_stop_monitoring(monitor_service_with_container, mock_monitor, mock_conn def test_stop_monitoring__multiple_calls(monitor_service_with_container, mock_monitor, mock_conn): aliases = frozenset({"instance-1"}) context = monitor_service_with_container.start_monitoring( - mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) + mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) monitor_service_with_container.stop_monitoring(context) mock_monitor.stop_monitoring.assert_called_once_with(context) monitor_service_with_container.stop_monitoring(context) diff --git a/tests/unit/test_tortoise_base_client.py b/tests/unit/test_tortoise_base_client.py new file mode 100644 index 00000000..6db65775 --- /dev/null +++ b/tests/unit/test_tortoise_base_client.py @@ -0,0 +1,142 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from aws_advanced_python_wrapper.tortoise_orm.backends.base.client import ( + TortoiseAwsClientPooledConnectionWrapper, + TortoiseAwsClientPooledTransactionContext) + + +class TestTortoiseAwsClientPooledConnectionWrapper: + def test_init(self): + mock_client = MagicMock() + mock_pool_lock = MagicMock() + + wrapper = TortoiseAwsClientPooledConnectionWrapper(mock_client, mock_pool_lock) + + assert wrapper.client == mock_client + assert wrapper._pool_init_lock == mock_pool_lock + assert wrapper.connection is None + + @pytest.mark.asyncio + async def test_ensure_connection(self): + mock_client = MagicMock() + mock_client._pool = None + mock_client.create_connection = AsyncMock() + mock_pool_lock = AsyncMock() + + wrapper = TortoiseAwsClientPooledConnectionWrapper(mock_client, mock_pool_lock) + + await wrapper.ensure_connection() + mock_client.create_connection.assert_called_once_with(with_db=True) + + @pytest.mark.asyncio + async def test_context_manager(self): + mock_client = MagicMock() + mock_client._pool = MagicMock() + mock_pool_connection = MagicMock() + mock_pool_connection.release = AsyncMock() + mock_client._pool.acquire = AsyncMock(return_value=mock_pool_connection) + mock_pool_lock = AsyncMock() + + wrapper = TortoiseAwsClientPooledConnectionWrapper(mock_client, mock_pool_lock) + + async with wrapper as conn: + assert conn == mock_pool_connection + assert wrapper.connection == mock_pool_connection + + # Verify release was called on exit + mock_pool_connection.release.assert_called_once() + + +class TestTortoiseAwsClientPooledTransactionContext: + def test_init(self): + mock_client = MagicMock() + mock_client.connection_name = "test_conn" + mock_pool_lock = MagicMock() + + context = TortoiseAwsClientPooledTransactionContext(mock_client, mock_pool_lock) + + assert context.client == mock_client + assert context.connection_name == "test_conn" + assert context._pool_init_lock == mock_pool_lock + + @pytest.mark.asyncio + async def test_ensure_connection(self): + mock_client = MagicMock() + mock_client._parent._pool = None + mock_client._parent.create_connection = AsyncMock() + mock_pool_lock = AsyncMock() + + context = TortoiseAwsClientPooledTransactionContext(mock_client, mock_pool_lock) + + await context.ensure_connection() + mock_client._parent.create_connection.assert_called_once_with(with_db=True) + + @pytest.mark.asyncio + async def test_context_manager_commit(self): + mock_client = MagicMock() + mock_client._parent._pool = MagicMock() + mock_client.connection_name = "test_conn" + mock_client._finalized = False + mock_client.begin = AsyncMock() + mock_client.commit = AsyncMock() + mock_pool_connection = MagicMock() + mock_pool_connection.release = AsyncMock() + mock_client._parent._pool.acquire = AsyncMock(return_value=mock_pool_connection) + mock_pool_lock = AsyncMock() + + context = TortoiseAwsClientPooledTransactionContext(mock_client, mock_pool_lock) + + with patch('tortoise.connection.connections') as mock_connections: + mock_connections.set.return_value = "test_token" + + async with context as client: + assert client == mock_client + assert mock_client._connection == mock_pool_connection + + # Verify commit was called and connection was released + mock_client.commit.assert_called_once() + mock_pool_connection.release.assert_called_once() + + @pytest.mark.asyncio + async def test_context_manager_rollback_on_exception(self): + mock_client = MagicMock() + mock_client._parent._pool = MagicMock() + mock_client.connection_name = "test_conn" + mock_client._finalized = False + mock_client.begin = AsyncMock() + mock_client.rollback = AsyncMock() + mock_pool_connection = MagicMock() + mock_pool_connection.release = AsyncMock() + mock_client._parent._pool.acquire = AsyncMock(return_value=mock_pool_connection) + mock_pool_lock = AsyncMock() + + context = TortoiseAwsClientPooledTransactionContext(mock_client, mock_pool_lock) + + with patch('tortoise.connection.connections') as mock_connections: + mock_connections.set.return_value = "test_token" + + try: + async with context: + raise ValueError("Test exception") + except ValueError: + pass + + # Verify rollback was called and connection was released + mock_client.rollback.assert_called_once() + mock_pool_connection.release.assert_called_once() diff --git a/tests/unit/test_tortoise_mysql_client.py b/tests/unit/test_tortoise_mysql_client.py new file mode 100644 index 00000000..98be046f --- /dev/null +++ b/tests/unit/test_tortoise_mysql_client.py @@ -0,0 +1,510 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mysql.connector import errors +from tortoise.exceptions import (DBConnectionError, IntegrityError, + OperationalError, TransactionManagementError) + +from aws_advanced_python_wrapper.tortoise_orm.backends.mysql.client import ( + AwsMySQLClient, TransactionWrapper, _gen_savepoint_name, + translate_exceptions) + + +class TestTranslateExceptions: + @pytest.mark.asyncio + async def test_translate_operational_error(self): + @translate_exceptions + async def test_func(self): + raise errors.OperationalError("Test error") + + with pytest.raises(OperationalError): + await test_func(None) + + @pytest.mark.asyncio + async def test_translate_integrity_error(self): + @translate_exceptions + async def test_func(self): + raise errors.IntegrityError("Test error") + + with pytest.raises(IntegrityError): + await test_func(None) + + @pytest.mark.asyncio + async def test_no_exception(self): + @translate_exceptions + async def test_func(self): + return "success" + + result = await test_func(None) + assert result == "success" + + +class TestAwsMySQLClient: + def test_init(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="utf8mb4", + storage_engine="InnoDB", + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + assert client.user == "test_user" + assert client.password == "test_pass" + assert client.database == "test_db" + assert client.host == "localhost" + assert client.port == 3306 + assert client.charset == "utf8mb4" + assert client.storage_engine == "InnoDB" + + def test_init_defaults(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + assert client.charset == "utf8mb4" + assert client.storage_engine == "innodb" + + @pytest.mark.asyncio + async def test_create_connection_invalid_charset(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + charset="invalid_charset", + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + with pytest.raises(DBConnectionError, match="Unknown character set"): + await client.create_connection(with_db=True) + + @pytest.mark.asyncio + async def test_create_connection_success(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + with patch('aws_advanced_python_wrapper.tortoise_orm.backends.mysql.client.logger'): + await client.create_connection(with_db=True) + + assert client._template["user"] == "test_user" + assert client._template["database"] == "test_db" + assert client._template["autocommit"] is True + + @pytest.mark.asyncio + async def test_close(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + # close() method now does nothing (AWS wrapper handles cleanup) + await client.close() + # No assertions needed since method is a no-op + + def test_acquire_connection(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + connection_wrapper = client.acquire_connection() + assert connection_wrapper.client == client + + @pytest.mark.asyncio + async def test_execute_insert(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.lastrowid = 123 + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + mock_cursor.execute = AsyncMock() + + with patch('aws_advanced_python_wrapper.tortoise_orm.backends.mysql.client.logger'): + result = await client.execute_insert("INSERT INTO test VALUES (?)", ["value"]) + + assert result == 123 + mock_cursor.execute.assert_called_once_with("INSERT INTO test VALUES (?)", ["value"]) + + @pytest.mark.asyncio + async def test_execute_many_with_transactions(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + storage_engine="innodb", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + mock_connection = MagicMock() + mock_cursor = MagicMock() + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + with patch.object(client, '_execute_many_with_transaction') as mock_execute_many_tx: + mock_execute_many_tx.return_value = None + + with patch('aws_advanced_python_wrapper.tortoise_orm.backends.mysql.client.logger'): + await client.execute_many("INSERT INTO test VALUES (?)", [["val1"], ["val2"]]) + + mock_execute_many_tx.assert_called_once_with( + mock_cursor, mock_connection, "INSERT INTO test VALUES (?)", [["val1"], ["val2"]] + ) + + @pytest.mark.asyncio + async def test_execute_query(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.rowcount = 2 + mock_cursor.description = [("id",), ("name",)] + + with patch.object(client, 'acquire_connection') as mock_acquire: + mock_acquire.return_value.__aenter__ = AsyncMock(return_value=mock_connection) + mock_acquire.return_value.__aexit__ = AsyncMock() + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + mock_cursor.execute = AsyncMock() + mock_cursor.fetchall = AsyncMock(return_value=[(1, "test"), (2, "test2")]) + + with patch('aws_advanced_python_wrapper.tortoise_orm.backends.mysql.client.logger'): + rowcount, results = await client.execute_query("SELECT * FROM test") + + assert rowcount == 2 + assert len(results) == 2 + assert results[0] == {"id": 1, "name": "test"} + + @pytest.mark.asyncio + async def test_execute_query_dict(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + with patch.object(client, 'execute_query') as mock_execute_query: + mock_execute_query.return_value = (2, [{"id": 1}, {"id": 2}]) + + results = await client.execute_query_dict("SELECT * FROM test") + + assert results == [{"id": 1}, {"id": 2}] + + def test_in_transaction(self): + with patch('builtins.print'): # Mock the print statement + client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + client._disable_pool_for_testing() + + transaction_context = client._in_transaction() + assert transaction_context is not None + + +class TestTransactionWrapper: + def test_init(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + + assert wrapper.connection_name == "test_conn" + assert wrapper._parent == parent_client + assert wrapper._finalized is False + assert wrapper._savepoint is None + + @pytest.mark.asyncio + async def test_begin(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.begin() + + mock_connection.set_autocommit.assert_called_once_with(False) + assert wrapper._finalized is False + + @pytest.mark.asyncio + async def test_commit(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.commit = AsyncMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.commit() + + mock_connection.commit.assert_called_once() + mock_connection.set_autocommit.assert_called_once_with(True) + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_commit_already_finalized(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + wrapper._finalized = True + + with pytest.raises(TransactionManagementError, match="Transaction already finalized"): + await wrapper.commit() + + @pytest.mark.asyncio + async def test_rollback(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_connection.rollback = AsyncMock() + mock_connection.set_autocommit = AsyncMock() + wrapper._connection = mock_connection + + await wrapper.rollback() + + mock_connection.rollback.assert_called_once() + mock_connection.set_autocommit.assert_called_once_with(True) + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_savepoint(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute = AsyncMock() + + wrapper._connection = mock_connection + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + await wrapper.savepoint() + + assert wrapper._savepoint is not None + assert wrapper._savepoint.startswith("tortoise_savepoint_") + mock_cursor.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_savepoint_rollback(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + wrapper._savepoint = "test_savepoint" + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute = AsyncMock() + + wrapper._connection = mock_connection + mock_connection.cursor.return_value.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_connection.cursor.return_value.__aexit__ = AsyncMock() + + await wrapper.savepoint_rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK TO SAVEPOINT test_savepoint") + assert wrapper._savepoint is None + assert wrapper._finalized is True + + @pytest.mark.asyncio + async def test_savepoint_rollback_no_savepoint(self): + with patch('builtins.print'): # Mock the print statement + parent_client = AwsMySQLClient( + user="test_user", + password="test_pass", + database="test_db", + host="localhost", + port=3306, + connection_name="test_conn", + plugins="aurora_connection_tracker,failover" + ) + parent_client._disable_pool_for_testing() + + wrapper = TransactionWrapper(parent_client) + wrapper._savepoint = None + + with pytest.raises(TransactionManagementError, match="No savepoint to rollback to"): + await wrapper.savepoint_rollback() + + +class TestGenSavepointName: + def test_gen_savepoint_name(self): + name1 = _gen_savepoint_name() + name2 = _gen_savepoint_name() + + assert name1.startswith("tortoise_savepoint_") + assert name2.startswith("tortoise_savepoint_") + assert name1 != name2 # Should be unique diff --git a/tests/test_weight_random_host_selector.py b/tests/unit/test_weight_random_host_selector.py similarity index 100% rename from tests/test_weight_random_host_selector.py rename to tests/unit/test_weight_random_host_selector.py