Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,94 @@ async with async_connection.cursor() as cursor:
rows = await cursor.fetchmany(size=5)
rows = await cursor.fetchall()
```

## Query parameters

### Standard mode (`pyformat=True`)

Pass `pyformat=True` to `connect()` to use familiar Python DB-API
parameter syntax. The driver will convert placeholders and infer YDB types
from Python values automatically.

**Named parameters** — `%(name)s` with a `dict`:

```python
connection = ydb_dbapi.connect(
host="localhost", port="2136", database="/local",
pyformat=True,
)

with connection.cursor() as cursor:
cursor.execute(
"SELECT * FROM users WHERE id = %(id)s AND active = %(active)s",
{"id": 42, "active": True},
)
```

**Positional parameters** — `%s` with a `list` or `tuple`:

```python
with connection.cursor() as cursor:
cursor.execute(
"INSERT INTO users (id, name, score) VALUES (%s, %s, %s)",
[1, "Alice", 9.8],
)
```

Use `%%` to insert a literal `%` character in the query.

**Automatic type mapping:**

| Python type | YDB type |
|--------------------|-------------|
| `bool` | `Bool` |
| `int` | `Int64` |
| `float` | `Double` |
| `str` | `Utf8` |
| `bytes` | `String` |
| `datetime.datetime`| `Timestamp` |
| `datetime.date` | `Date` |
| `datetime.timedelta`| `Interval` |
| `decimal.Decimal` | `Decimal(22, 9)` |
| `None` | `NULL` (passed as-is) |

**Explicit types with `ydb.TypedValue`:**

When automatic inference is not suitable (e.g. you need `Int32` instead of
`Int64`, or `Json`), wrap the value in `ydb.TypedValue` — it will be passed
through unchanged:

```python
import ydb

with connection.cursor() as cursor:
cursor.execute(
"INSERT INTO events (id, payload) VALUES (%(id)s, %(payload)s)",
{
"id": ydb.TypedValue(99, ydb.PrimitiveType.Int32),
"payload": ydb.TypedValue('{"key": "value"}', ydb.PrimitiveType.Json),
},
)
```

### Native YDB mode (default, deprecated)

> **Deprecated.** Native YDB mode is the current default for backwards
> compatibility, but it will be removed in a future release. Migrate to
> `pyformat=True` at your earliest convenience.

By default (`pyformat=False`) the driver passes the query and parameters
directly to the YDB SDK without any transformation. Use `$name` placeholders
in the query and supply a `dict` with `$`-prefixed keys:

```python
connection = ydb_dbapi.connect(
host="localhost", port="2136", database="/local",
)

with connection.cursor() as cursor:
cursor.execute(
"SELECT * FROM users WHERE id = $id",
{"$id": ydb.TypedValue(42, ydb.PrimitiveType.Int64)},
)
```
170 changes: 170 additions & 0 deletions tests/test_convert_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import annotations

import datetime
import decimal

import ydb
from ydb_dbapi.utils import convert_query_parameters


class TestNamedStyle:
"""%(name)s placeholders with a dict."""

def test_basic_query_transformation(self):
q, _ = convert_query_parameters(
"SELECT %(id)s FROM t WHERE name = %(name)s",
{"id": 1, "name": "alice"},
)
assert q == "SELECT $id FROM t WHERE name = $name"

def test_keys_prefixed_with_dollar(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": 1})
assert "$x" in p
assert "x" not in p

def test_int(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": 42})
assert p["$x"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64)

def test_float(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": 3.14})
assert p["$x"] == ydb.TypedValue(3.14, ydb.PrimitiveType.Double)

def test_str(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": "hello"})
assert p["$x"] == ydb.TypedValue("hello", ydb.PrimitiveType.Utf8)

def test_bytes(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": b"data"})
assert p["$x"] == ydb.TypedValue(b"data", ydb.PrimitiveType.String)

def test_bool(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": True})
assert p["$x"] == ydb.TypedValue(True, ydb.PrimitiveType.Bool)

def test_bool_not_confused_with_int(self):
# bool is subclass of int — must map to Bool, not Int64
_, p = convert_query_parameters("SELECT %(x)s", {"x": False})
assert p["$x"].value_type == ydb.PrimitiveType.Bool

def test_date(self):
d = datetime.date(2024, 1, 15)
_, p = convert_query_parameters("SELECT %(x)s", {"x": d})
assert p["$x"] == ydb.TypedValue(d, ydb.PrimitiveType.Date)

def test_datetime(self):
tz = datetime.timezone.utc
dt = datetime.datetime(2024, 1, 15, 12, 0, 0, tzinfo=tz)
_, p = convert_query_parameters("SELECT %(x)s", {"x": dt})
assert p["$x"] == ydb.TypedValue(dt, ydb.PrimitiveType.Timestamp)

def test_datetime_not_confused_with_date(self):
# datetime is subclass of date — must map to Timestamp, not Date
tz = datetime.timezone.utc
dt = datetime.datetime(2024, 6, 1, 0, 0, 0, tzinfo=tz)
_, p = convert_query_parameters("SELECT %(x)s", {"x": dt})
assert p["$x"].value_type == ydb.PrimitiveType.Timestamp

def test_timedelta(self):
td = datetime.timedelta(seconds=60)
_, p = convert_query_parameters("SELECT %(x)s", {"x": td})
assert p["$x"] == ydb.TypedValue(td, ydb.PrimitiveType.Interval)

def test_decimal(self):
d = decimal.Decimal("3.14")
_, p = convert_query_parameters("SELECT %(x)s", {"x": d})
assert isinstance(p["$x"], ydb.TypedValue)
assert p["$x"].value == d

def test_none_passed_as_is(self):
_, p = convert_query_parameters("SELECT %(x)s", {"x": None})
assert p["$x"] is None

def test_unknown_type_passed_as_is(self):
obj = object()
_, p = convert_query_parameters("SELECT %(x)s", {"x": obj})
assert p["$x"] is obj

def test_multiple_params(self):
q, p = convert_query_parameters(
"INSERT INTO t VALUES (%(a)s, %(b)s, %(c)s)",
{"a": 1, "b": "hi", "c": True},
)
assert q == "INSERT INTO t VALUES ($a, $b, $c)"
assert "$a" in p
assert "$b" in p
assert "$c" in p

def test_percent_percent_escape(self):
q, _ = convert_query_parameters(
"SELECT %% as pct, %(x)s", {"x": 1}
)
assert q == "SELECT % as pct, $x"

def test_empty_params(self):
q, p = convert_query_parameters("SELECT 1", {})
assert q == "SELECT 1"
assert p == {}


class TestCustomTypes:
"""Pass-through for ydb.TypedValue (explicit type hint)."""

def test_typed_value_passed_through(self):
tv = ydb.TypedValue(42, ydb.PrimitiveType.Int32)
_, p = convert_query_parameters("SELECT %(x)s", {"x": tv})
assert p["$x"] is tv

def test_typed_value_not_double_wrapped(self):
tv = ydb.TypedValue("hello", ydb.PrimitiveType.Utf8)
_, p = convert_query_parameters("SELECT %(x)s", {"x": tv})
assert isinstance(p["$x"], ydb.TypedValue)
assert p["$x"].value_type == ydb.PrimitiveType.Utf8

def test_typed_value_positional(self):
tv = ydb.TypedValue(99, ydb.PrimitiveType.Int32)
_, p = convert_query_parameters("SELECT %s", [tv])
assert p["$p1"] is tv

def test_unknown_type_passed_as_is(self):
# No TypedValue, no known type — value goes through unchanged
val = object()
_, p = convert_query_parameters("SELECT %(x)s", {"x": val})
assert p["$x"] is val


class TestPositionalStyle:
"""Positional %s placeholders with a list or tuple."""

def test_basic_list(self):
q, p = convert_query_parameters("SELECT %s", [42])
assert q == "SELECT $p1"
assert p["$p1"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64)

def test_basic_tuple(self):
q, p = convert_query_parameters("SELECT %s", (42,))
assert q == "SELECT $p1"
assert p["$p1"] == ydb.TypedValue(42, ydb.PrimitiveType.Int64)

def test_multiple_params_numbered_sequentially(self):
q, p = convert_query_parameters(
"INSERT INTO t VALUES (%s, %s, %s)", [1, "hi", 3.14]
)
assert q == "INSERT INTO t VALUES ($p1, $p2, $p3)"
assert p["$p1"] == ydb.TypedValue(1, ydb.PrimitiveType.Int64)
assert p["$p2"] == ydb.TypedValue("hi", ydb.PrimitiveType.Utf8)
assert p["$p3"] == ydb.TypedValue(3.14, ydb.PrimitiveType.Double)

def test_none_passed_as_is(self):
_, p = convert_query_parameters("SELECT %s", [None])
assert p["$p1"] is None

def test_percent_percent_escape(self):
q, p = convert_query_parameters("SELECT %%, %s", [7])
assert q == "SELECT %, $p1"
assert p["$p1"] == ydb.TypedValue(7, ydb.PrimitiveType.Int64)

def test_empty_list(self):
q, p = convert_query_parameters("SELECT 1", [])
assert q == "SELECT 1"
assert p == {}
68 changes: 64 additions & 4 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,15 @@ def __init__(
root_certificates_path: str | None = None,
root_certificates: str | None = None,
driver_config_kwargs: dict | None = None,
pyformat: bool = False,
**kwargs: Any,
) -> None:
protocol = protocol if protocol else "grpc"
self.endpoint = f"{protocol}://{host}:{port}"
self.credentials = prepare_credentials(credentials)
self.database = database
self.table_path_prefix = ydb_table_path_prefix
self.pyformat = pyformat

self.connection_kwargs: dict = kwargs

Expand Down Expand Up @@ -216,6 +218,7 @@ def __init__(
root_certificates_path: str | None = None,
root_certificates: str | None = None,
driver_config_kwargs: dict | None = None,
pyformat: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -229,6 +232,7 @@ def __init__(
root_certificates_path=root_certificates_path,
root_certificates=root_certificates,
driver_config_kwargs=driver_config_kwargs,
pyformat=pyformat,
**kwargs,
)
self._current_cursor: Cursor | None = None
Expand All @@ -242,6 +246,7 @@ def cursor(self) -> Cursor:
table_path_prefix=self.table_path_prefix,
request_settings=self.request_settings,
retry_settings=self.retry_settings,
pyformat=self.pyformat,
)

def wait_ready(self, timeout: int = 10) -> None:
Expand Down Expand Up @@ -411,6 +416,7 @@ def __init__(
root_certificates_path: str | None = None,
root_certificates: str | None = None,
driver_config_kwargs: dict | None = None,
pyformat: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -424,6 +430,7 @@ def __init__(
root_certificates_path=root_certificates_path,
root_certificates=root_certificates,
driver_config_kwargs=driver_config_kwargs,
pyformat=pyformat,
**kwargs,
)
self._current_cursor: AsyncCursor | None = None
Expand All @@ -437,6 +444,7 @@ def cursor(self) -> AsyncCursor:
table_path_prefix=self.table_path_prefix,
request_settings=self.request_settings,
retry_settings=self.retry_settings,
pyformat=self.pyformat,
)

async def wait_ready(self, timeout: int = 10) -> None:
Expand Down Expand Up @@ -593,13 +601,65 @@ async def _invalidate_session(self) -> None:
self._session = None


def connect(*args: Any, **kwargs: Any) -> Connection:
conn = Connection(*args, **kwargs)
def connect(
host: str = "",
port: str = "",
database: str = "",
ydb_table_path_prefix: str = "",
protocol: str | None = None,
credentials: ydb.Credentials | dict | str | None = None,
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
root_certificates_path: str | None = None,
root_certificates: str | None = None,
driver_config_kwargs: dict | None = None,
pyformat: bool = False,
**kwargs: Any,
) -> Connection:
conn = Connection(
host=host,
port=port,
database=database,
ydb_table_path_prefix=ydb_table_path_prefix,
protocol=protocol,
credentials=credentials,
ydb_session_pool=ydb_session_pool,
root_certificates_path=root_certificates_path,
root_certificates=root_certificates,
driver_config_kwargs=driver_config_kwargs,
pyformat=pyformat,
**kwargs,
)
conn.wait_ready()
return conn


async def async_connect(*args: Any, **kwargs: Any) -> AsyncConnection:
conn = AsyncConnection(*args, **kwargs)
async def async_connect(
host: str = "",
port: str = "",
database: str = "",
ydb_table_path_prefix: str = "",
protocol: str | None = None,
credentials: ydb.Credentials | dict | str | None = None,
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
root_certificates_path: str | None = None,
root_certificates: str | None = None,
driver_config_kwargs: dict | None = None,
pyformat: bool = False,
**kwargs: Any,
) -> AsyncConnection:
conn = AsyncConnection(
host=host,
port=port,
database=database,
ydb_table_path_prefix=ydb_table_path_prefix,
protocol=protocol,
credentials=credentials,
ydb_session_pool=ydb_session_pool,
root_certificates_path=root_certificates_path,
root_certificates=root_certificates,
driver_config_kwargs=driver_config_kwargs,
pyformat=pyformat,
**kwargs,
)
await conn.wait_ready()
return conn
Loading
Loading