diff --git a/README.md b/README.md index 5666a2c..8f8c717 100644 --- a/README.md +++ b/README.md @@ -47,3 +47,94 @@ async with async_connection.cursor() as cursor: rows = await cursor.fetchmany(size=5) rows = await cursor.fetchall() ``` + +## Query parameters + +### Standard mode (`pyformat=True`) + +Pass `pyformat=True` to `connect()` to use familiar Python DB-API +parameter syntax. The driver will convert placeholders and infer YDB types +from Python values automatically. + +**Named parameters** — `%(name)s` with a `dict`: + +```python +connection = ydb_dbapi.connect( + host="localhost", port="2136", database="/local", + pyformat=True, +) + +with connection.cursor() as cursor: + cursor.execute( + "SELECT * FROM users WHERE id = %(id)s AND active = %(active)s", + {"id": 42, "active": True}, + ) +``` + +**Positional parameters** — `%s` with a `list` or `tuple`: + +```python +with connection.cursor() as cursor: + cursor.execute( + "INSERT INTO users (id, name, score) VALUES (%s, %s, %s)", + [1, "Alice", 9.8], + ) +``` + +Use `%%` to insert a literal `%` character in the query. Placeholder conversion (including `%%` handling) is applied only when the `parameters` argument of `execute()` is not `None`; a query executed without parameters is sent to YDB unchanged. + +**Automatic type mapping:** + +| Python type | YDB type | +|--------------------|-------------| +| `bool` | `Bool` | +| `int` | `Int64` | +| `float` | `Double` | +| `str` | `Utf8` | +| `bytes` | `String` | +| `datetime.datetime`| `Timestamp` | +| `datetime.date` | `Date` | +| `datetime.timedelta`| `Interval` | +| `decimal.Decimal` | `Decimal(22, 9)` | +| `None` | `NULL` (passed as-is) | + +**Explicit types with `ydb.TypedValue`:** + +When automatic inference is not suitable (e.g. 
you need `Int32` instead of +`Int64`, or `Json`), wrap the value in `ydb.TypedValue` — it will be passed +through unchanged: + +```python +import ydb + +with connection.cursor() as cursor: + cursor.execute( + "INSERT INTO events (id, payload) VALUES (%(id)s, %(payload)s)", + { + "id": ydb.TypedValue(99, ydb.PrimitiveType.Int32), + "payload": ydb.TypedValue('{"key": "value"}', ydb.PrimitiveType.Json), + }, + ) +``` + +### Native YDB mode (default, deprecated) + +> **Deprecated.** Native YDB mode is the current default for backwards +> compatibility, but it will be removed in a future release. Migrate to +> `pyformat=True` at your earliest convenience. + +By default (`pyformat=False`) the driver passes the query and parameters +directly to the YDB SDK without any transformation. Use `$name` placeholders +in the query and supply a `dict` with `$`-prefixed keys: + +```python +connection = ydb_dbapi.connect( + host="localhost", port="2136", database="/local", +) + +with connection.cursor() as cursor: + cursor.execute( + "SELECT * FROM users WHERE id = $id", + {"$id": ydb.TypedValue(42, ydb.PrimitiveType.Int64)}, + ) +``` diff --git a/tests/test_convert_parameters.py b/tests/test_convert_parameters.py new file mode 100644 index 0000000..1031f0a --- /dev/null +++ b/tests/test_convert_parameters.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import datetime +import decimal + +import ydb +from ydb_dbapi.utils import convert_query_parameters + + +class TestNamedStyle: + """%(name)s placeholders with a dict.""" + + def test_basic_query_transformation(self): + q, _ = convert_query_parameters( + "SELECT %(id)s FROM t WHERE name = %(name)s", + {"id": 1, "name": "alice"}, + ) + assert q == "SELECT $id FROM t WHERE name = $name" + + def test_keys_prefixed_with_dollar(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": 1}) + assert "$x" in p + assert "x" not in p + + def test_int(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": 42}) + assert 
p["$x"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64) + + def test_float(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": 3.14}) + assert p["$x"] == ydb.TypedValue(3.14, ydb.PrimitiveType.Double) + + def test_str(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": "hello"}) + assert p["$x"] == ydb.TypedValue("hello", ydb.PrimitiveType.Utf8) + + def test_bytes(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": b"data"}) + assert p["$x"] == ydb.TypedValue(b"data", ydb.PrimitiveType.String) + + def test_bool(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": True}) + assert p["$x"] == ydb.TypedValue(True, ydb.PrimitiveType.Bool) + + def test_bool_not_confused_with_int(self): + # bool is subclass of int — must map to Bool, not Int64 + _, p = convert_query_parameters("SELECT %(x)s", {"x": False}) + assert p["$x"].value_type == ydb.PrimitiveType.Bool + + def test_date(self): + d = datetime.date(2024, 1, 15) + _, p = convert_query_parameters("SELECT %(x)s", {"x": d}) + assert p["$x"] == ydb.TypedValue(d, ydb.PrimitiveType.Date) + + def test_datetime(self): + tz = datetime.timezone.utc + dt = datetime.datetime(2024, 1, 15, 12, 0, 0, tzinfo=tz) + _, p = convert_query_parameters("SELECT %(x)s", {"x": dt}) + assert p["$x"] == ydb.TypedValue(dt, ydb.PrimitiveType.Timestamp) + + def test_datetime_not_confused_with_date(self): + # datetime is subclass of date — must map to Timestamp, not Date + tz = datetime.timezone.utc + dt = datetime.datetime(2024, 6, 1, 0, 0, 0, tzinfo=tz) + _, p = convert_query_parameters("SELECT %(x)s", {"x": dt}) + assert p["$x"].value_type == ydb.PrimitiveType.Timestamp + + def test_timedelta(self): + td = datetime.timedelta(seconds=60) + _, p = convert_query_parameters("SELECT %(x)s", {"x": td}) + assert p["$x"] == ydb.TypedValue(td, ydb.PrimitiveType.Interval) + + def test_decimal(self): + d = decimal.Decimal("3.14") + _, p = convert_query_parameters("SELECT %(x)s", {"x": d}) + assert 
isinstance(p["$x"], ydb.TypedValue) + assert p["$x"].value == d + + def test_none_passed_as_is(self): + _, p = convert_query_parameters("SELECT %(x)s", {"x": None}) + assert p["$x"] is None + + def test_unknown_type_passed_as_is(self): + obj = object() + _, p = convert_query_parameters("SELECT %(x)s", {"x": obj}) + assert p["$x"] is obj + + def test_multiple_params(self): + q, p = convert_query_parameters( + "INSERT INTO t VALUES (%(a)s, %(b)s, %(c)s)", + {"a": 1, "b": "hi", "c": True}, + ) + assert q == "INSERT INTO t VALUES ($a, $b, $c)" + assert "$a" in p + assert "$b" in p + assert "$c" in p + + def test_percent_percent_escape(self): + q, _ = convert_query_parameters( + "SELECT %% as pct, %(x)s", {"x": 1} + ) + assert q == "SELECT % as pct, $x" + + def test_empty_params(self): + q, p = convert_query_parameters("SELECT 1", {}) + assert q == "SELECT 1" + assert p == {} + + +class TestCustomTypes: + """Pass-through for ydb.TypedValue (explicit type hint).""" + + def test_typed_value_passed_through(self): + tv = ydb.TypedValue(42, ydb.PrimitiveType.Int32) + _, p = convert_query_parameters("SELECT %(x)s", {"x": tv}) + assert p["$x"] is tv + + def test_typed_value_not_double_wrapped(self): + tv = ydb.TypedValue("hello", ydb.PrimitiveType.Utf8) + _, p = convert_query_parameters("SELECT %(x)s", {"x": tv}) + assert isinstance(p["$x"], ydb.TypedValue) + assert p["$x"].value_type == ydb.PrimitiveType.Utf8 + + def test_typed_value_positional(self): + tv = ydb.TypedValue(99, ydb.PrimitiveType.Int32) + _, p = convert_query_parameters("SELECT %s", [tv]) + assert p["$p1"] is tv + + def test_unknown_type_passed_as_is(self): + # No TypedValue, no known type — value goes through unchanged + val = object() + _, p = convert_query_parameters("SELECT %(x)s", {"x": val}) + assert p["$x"] is val + + +class TestPositionalStyle: + """Positional %s placeholders with a list or tuple.""" + + def test_basic_list(self): + q, p = convert_query_parameters("SELECT %s", [42]) + assert q == 
"SELECT $p1" + assert p["$p1"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64) + + def test_basic_tuple(self): + q, p = convert_query_parameters("SELECT %s", (42,)) + assert q == "SELECT $p1" + assert p["$p1"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64) + + def test_multiple_params_numbered_sequentially(self): + q, p = convert_query_parameters( + "INSERT INTO t VALUES (%s, %s, %s)", [1, "hi", 3.14] + ) + assert q == "INSERT INTO t VALUES ($p1, $p2, $p3)" + assert p["$p1"] == ydb.TypedValue(1, ydb.PrimitiveType.Int64) + assert p["$p2"] == ydb.TypedValue("hi", ydb.PrimitiveType.Utf8) + assert p["$p3"] == ydb.TypedValue(3.14, ydb.PrimitiveType.Double) + + def test_none_passed_as_is(self): + _, p = convert_query_parameters("SELECT %s", [None]) + assert p["$p1"] is None + + def test_percent_percent_escape(self): + q, p = convert_query_parameters("SELECT %%, %s", [7]) + assert q == "SELECT %, $p1" + assert p["$p1"] == ydb.TypedValue(7, ydb.PrimitiveType.Int64) + + def test_empty_list(self): + q, p = convert_query_parameters("SELECT 1", []) + assert q == "SELECT 1" + assert p == {} diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 1ee215b..6bc9b4d 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -82,6 +82,7 @@ def __init__( root_certificates_path: str | None = None, root_certificates: str | None = None, driver_config_kwargs: dict | None = None, + pyformat: bool = False, **kwargs: Any, ) -> None: protocol = protocol if protocol else "grpc" @@ -89,6 +90,7 @@ def __init__( self.credentials = prepare_credentials(credentials) self.database = database self.table_path_prefix = ydb_table_path_prefix + self.pyformat = pyformat self.connection_kwargs: dict = kwargs @@ -216,6 +218,7 @@ def __init__( root_certificates_path: str | None = None, root_certificates: str | None = None, driver_config_kwargs: dict | None = None, + pyformat: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -229,6 +232,7 @@ def __init__( 
root_certificates_path=root_certificates_path, root_certificates=root_certificates, driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, **kwargs, ) self._current_cursor: Cursor | None = None @@ -242,6 +246,7 @@ def cursor(self) -> Cursor: table_path_prefix=self.table_path_prefix, request_settings=self.request_settings, retry_settings=self.retry_settings, + pyformat=self.pyformat, ) def wait_ready(self, timeout: int = 10) -> None: @@ -411,6 +416,7 @@ def __init__( root_certificates_path: str | None = None, root_certificates: str | None = None, driver_config_kwargs: dict | None = None, + pyformat: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -424,6 +430,7 @@ def __init__( root_certificates_path=root_certificates_path, root_certificates=root_certificates, driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, **kwargs, ) self._current_cursor: AsyncCursor | None = None @@ -437,6 +444,7 @@ def cursor(self) -> AsyncCursor: table_path_prefix=self.table_path_prefix, request_settings=self.request_settings, retry_settings=self.retry_settings, + pyformat=self.pyformat, ) async def wait_ready(self, timeout: int = 10) -> None: @@ -593,13 +601,65 @@ async def _invalidate_session(self) -> None: self._session = None -def connect(*args: Any, **kwargs: Any) -> Connection: - conn = Connection(*args, **kwargs) +def connect( + host: str = "", + port: str = "", + database: str = "", + ydb_table_path_prefix: str = "", + protocol: str | None = None, + credentials: ydb.Credentials | dict | str | None = None, + ydb_session_pool: SessionPool | AsyncSessionPool | None = None, + root_certificates_path: str | None = None, + root_certificates: str | None = None, + driver_config_kwargs: dict | None = None, + pyformat: bool = False, + **kwargs: Any, +) -> Connection: + conn = Connection( + host=host, + port=port, + database=database, + ydb_table_path_prefix=ydb_table_path_prefix, + protocol=protocol, + credentials=credentials, + 
ydb_session_pool=ydb_session_pool, + root_certificates_path=root_certificates_path, + root_certificates=root_certificates, + driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, + **kwargs, + ) conn.wait_ready() return conn -async def async_connect(*args: Any, **kwargs: Any) -> AsyncConnection: - conn = AsyncConnection(*args, **kwargs) +async def async_connect( + host: str = "", + port: str = "", + database: str = "", + ydb_table_path_prefix: str = "", + protocol: str | None = None, + credentials: ydb.Credentials | dict | str | None = None, + ydb_session_pool: SessionPool | AsyncSessionPool | None = None, + root_certificates_path: str | None = None, + root_certificates: str | None = None, + driver_config_kwargs: dict | None = None, + pyformat: bool = False, + **kwargs: Any, +) -> AsyncConnection: + conn = AsyncConnection( + host=host, + port=port, + database=database, + ydb_table_path_prefix=ydb_table_path_prefix, + protocol=protocol, + credentials=credentials, + ydb_session_pool=ydb_session_pool, + root_certificates_path=root_certificates_path, + root_certificates=root_certificates, + driver_config_kwargs=driver_config_kwargs, + pyformat=pyformat, + **kwargs, + ) await conn.wait_ready() return conn diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 7060401..fa04936 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -19,6 +19,7 @@ from .errors import InterfaceError from .errors import ProgrammingError from .utils import CursorStatus +from .utils import convert_query_parameters from .utils import handle_ydb_errors from .utils import maybe_get_current_trace_id @@ -26,13 +27,17 @@ from .connections import AsyncConnection from .connections import Connection - ParametersType = dict[ - str, - Union[ - Any, - tuple[Any, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]], - ydb.TypedValue, + ParametersType = Union[ + dict[ + str, + Union[ + Any, + tuple[Any, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]], + ydb.TypedValue, + ], ], + 
list[Any], + tuple[Any, ...], ] @@ -202,6 +207,7 @@ def __init__( retry_settings: ydb.RetrySettings, tx_context: ydb.QueryTxContext | None = None, table_path_prefix: str = "", + pyformat: bool = False, ) -> None: super().__init__() self._connection = connection @@ -211,6 +217,7 @@ def __init__( self._retry_settings = retry_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix + self._pyformat = pyformat self._stream: Iterator | None = None def fetchone(self) -> tuple | None: @@ -328,6 +335,10 @@ def execute( self._raise_if_running() query = self._append_table_path_prefix(query) + + if self._pyformat and parameters is not None: + query, parameters = convert_query_parameters(query, parameters) + self._begin_query() if self._tx_context is not None: @@ -379,6 +390,7 @@ def __init__( retry_settings: ydb.RetrySettings, tx_context: ydb.aio.QueryTxContext | None = None, table_path_prefix: str = "", + pyformat: bool = False, ) -> None: super().__init__() self._connection = connection @@ -388,6 +400,7 @@ def __init__( self._retry_settings = retry_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix + self._pyformat = pyformat self._stream: AsyncIterator | None = None def fetchone(self) -> tuple | None: @@ -506,6 +519,9 @@ async def execute( query = self._append_table_path_prefix(query) + if self._pyformat and parameters is not None: + query, parameters = convert_query_parameters(query, parameters) + self._begin_query() if self._tx_context is not None: diff --git a/ydb_dbapi/utils.py b/ydb_dbapi/utils.py index 5e0e1f0..fc09a44 100644 --- a/ydb_dbapi/utils.py +++ b/ydb_dbapi/utils.py @@ -1,8 +1,11 @@ from __future__ import annotations +import datetime +import decimal import functools import importlib.util import json +import re from enum import Enum from inspect import iscoroutinefunction from typing import Any @@ -156,3 +159,105 @@ def prepare_credentials( ) return ydb.AnonymousCredentials() + + +# Order matters: bool 
# Subclasses must precede their bases (bool before int, datetime before date)
# because the first isinstance() match wins.
_PYTHON_TO_YDB_TYPE: list[tuple[type, Any]] = [
    (bool, ydb.PrimitiveType.Bool),
    (int, ydb.PrimitiveType.Int64),
    (float, ydb.PrimitiveType.Double),
    (str, ydb.PrimitiveType.Utf8),
    (bytes, ydb.PrimitiveType.String),
    (datetime.datetime, ydb.PrimitiveType.Timestamp),
    (datetime.date, ydb.PrimitiveType.Date),
    (datetime.timedelta, ydb.PrimitiveType.Interval),
    (decimal.Decimal, ydb.DecimalType(22, 9)),
]

# Alternatives are tried left to right: an escaped percent (``%%``), a named
# placeholder (``%(name)s``, name captured in group 1), then a positional
# placeholder (``%s``). Compiled once instead of on every call.
_PLACEHOLDER_RE = re.compile(r"%%|%\((\w+)\)s|%s")


def _infer_ydb_type(value: Any) -> Any:
    """Return the YDB type mapped to *value*'s Python type, or ``None``.

    The lookup walks ``_PYTHON_TO_YDB_TYPE`` in order and returns the type
    paired with the first entry *value* is an instance of.
    """
    for python_type, ydb_type in _PYTHON_TO_YDB_TYPE:
        if isinstance(value, python_type):
            return ydb_type
    return None


def _wrap_value(value: Any) -> Any:
    """Wrap a Python value in ``ydb.TypedValue`` if a type can be inferred.

    ``ydb.TypedValue`` instances are returned as-is so callers can supply
    an explicit type for values whose type cannot be inferred automatically.
    ``None`` and values of unmapped types are likewise returned unchanged.
    """
    if value is None or isinstance(value, ydb.TypedValue):
        return value
    ydb_type = _infer_ydb_type(value)
    if ydb_type is None:
        return value
    return ydb.TypedValue(value, ydb_type)


def convert_query_parameters(
    query: str,
    parameters: dict | list | tuple,
) -> tuple[str, dict]:
    """Convert pyformat-style query and parameters to YDB format.

    Supports two parameter styles:

    Named (``%(name)s``) with a mapping::

        convert_query_parameters(
            "SELECT %(id)s", {"id": 42}
        )
        # -> ("SELECT $id", {"$id": TypedValue(42, Int64)})

    Positional (``%s``) with a sequence::

        convert_query_parameters(
            "SELECT %s, %s", [42, "hi"]
        )
        # -> ("SELECT $p1, $p2", {"$p1": TypedValue(42, Int64),
        #                         "$p2": TypedValue("hi", Utf8)})

    ``%%`` is converted to a literal ``%`` in both modes.

    Python-to-YDB type mapping:
        bool      -> Bool
        int       -> Int64
        float     -> Double
        str       -> Utf8
        bytes     -> String
        datetime  -> Timestamp
        date      -> Date
        timedelta -> Interval
        Decimal   -> Decimal(22, 9)
        None      -> passed as-is (NULL)

    Args:
        query: Query text containing ``%s`` / ``%(name)s`` placeholders.
        parameters: A ``list``/``tuple`` for positional placeholders;
            anything else is treated as a mapping and must provide
            ``.items()``.

    Returns:
        The rewritten query and a dict of ``$``-prefixed parameter names
        mapped to the (possibly wrapped) values. Every supplied value is
        bound, including ones no placeholder refers to.

    Raises:
        ProgrammingError: If a placeholder has no matching value — more
            ``%s`` than sequence items, a ``%(name)s`` whose key is missing,
            or a placeholder style that does not match the container type.
    """
    # Function-scope import: this is the DB-API error class that cursors.py
    # imports from the same package module.
    from .errors import ProgrammingError

    # Bind the values first so every placeholder can be checked against the
    # set of names that will actually be sent with the query.
    if isinstance(parameters, (list, tuple)):
        converted_params: dict = {
            f"$p{index}": _wrap_value(value)
            for index, value in enumerate(parameters, start=1)
        }
    else:
        converted_params = {
            f"${name}": _wrap_value(value)
            for name, value in parameters.items()
        }

    positional_index = 0

    def replace(m: re.Match) -> str:
        nonlocal positional_index
        token = m.group(0)
        if token == "%%":
            return "%"
        name = m.group(1)
        if name is not None:
            ydb_name = f"${name}"
        else:
            # %s — positional, numbered in order of appearance
            positional_index += 1
            ydb_name = f"$p{positional_index}"
        if ydb_name not in converted_params:
            # Without this check the rewritten query would reference a
            # parameter that is never bound.
            raise ProgrammingError(
                f"No parameter value supplied for placeholder {token!r} "
                f"(expected parameter {ydb_name!r})"
            )
        return ydb_name

    converted_query = _PLACEHOLDER_RE.sub(replace, query)

    return converted_query, converted_params