From 2caf0baf7483d1772ae81a607d9e8decab156baa Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 02:08:21 +0000 Subject: [PATCH 1/8] Add chdb in-process backend via interface="chdb" --- clickhouse_connect/driver/__init__.py | 61 ++ clickhouse_connect/driver/chdbasync.py | 327 ++++++++ clickhouse_connect/driver/chdbclient.py | 695 ++++++++++++++++++ setup.py | 1 + tests/test_bare_install.py | 26 + tests/test_requirements.txt | 1 + .../unit_tests/test_driver/test_chdbclient.py | 320 ++++++++ 7 files changed, 1431 insertions(+) create mode 100644 clickhouse_connect/driver/chdbasync.py create mode 100644 clickhouse_connect/driver/chdbclient.py create mode 100644 tests/unit_tests/test_driver/test_chdbclient.py diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index d0fe9f1d..cb42bcd4 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -159,6 +159,14 @@ def create_client( limits. Only available for query operations (not inserts). Default: False :return: ClickHouse Connect Client instance """ + if interface == "chdb": + return _create_chdb_client( + database=database, + settings=settings, + generic_args=generic_args, + kwargs=kwargs, + ) + host, username, password, port, database, interface = _parse_connection_params( host, username, password, port, database, interface, secure, dsn, kwargs ) @@ -264,6 +272,14 @@ async def create_async_client( limits. Only available for query operations (not inserts). 
Default: False :return: ClickHouse Connect AsyncClient instance """ + if interface == "chdb": + return _create_chdb_async_client( + database=database, + settings=settings, + generic_args=generic_args, + kwargs=kwargs, + ) + try: from clickhouse_connect.driver.asyncclient import AsyncClient as _AsyncClient except ModuleNotFoundError as ex: @@ -315,3 +331,48 @@ async def create_async_client( ) await client._initialize() return client + + +def _create_chdb_client( + *, + database: str, + settings: dict[str, Any] | None, + generic_args: dict[str, Any] | None, + kwargs: dict[str, Any], +) -> Client: + try: + from clickhouse_connect.driver.chdbclient import ChdbClient + except ImportError as ex: + if ex.name == "chdb" or (ex.name and ex.name.startswith("chdb")): + raise ImportError("chdb backend requires the chdb package. Install with: pip install 'clickhouse-connect[chdb]'") from ex + raise + + settings = dict(settings or {}) + if generic_args: + for name, value in generic_args.items(): + if name.startswith("ch_"): + name = name[3:] + settings[name] = value + return ChdbClient( + database=database, + settings=settings, + **kwargs, + ) + + +def _create_chdb_async_client( + *, + database: str, + settings: dict[str, Any] | None, + generic_args: dict[str, Any] | None, + kwargs: dict[str, Any], +): + try: + from clickhouse_connect.driver.chdbasync import AsyncChdbClient + except ImportError as ex: + if ex.name == "chdb" or (ex.name and ex.name.startswith("chdb")): + raise ImportError("chdb backend requires the chdb package. 
Install with: pip install 'clickhouse-connect[chdb]'") from ex + raise + + sync_client = _create_chdb_client(database=database, settings=settings, generic_args=generic_args, kwargs=kwargs) + return AsyncChdbClient(sync_client) # type: ignore[arg-type] diff --git a/clickhouse_connect/driver/chdbasync.py b/clickhouse_connect/driver/chdbasync.py new file mode 100644 index 00000000..d4998cc5 --- /dev/null +++ b/clickhouse_connect/driver/chdbasync.py @@ -0,0 +1,327 @@ +""" +Async wrapper around ChdbClient. + +chdb has no native async API, so this client delegates each call to the wrapped +sync ChdbClient via `asyncio.get_running_loop().run_in_executor(...)`. Because +ChdbClient serializes concurrent calls on a per-client `threading.Lock`, +gather()-style concurrency on a single AsyncChdbClient does not actually run in +parallel — for true parallelism, create multiple clients. +""" + +from __future__ import annotations + +import asyncio +import io +import logging +from collections.abc import Generator, Iterable, Sequence +from datetime import tzinfo +from typing import TYPE_CHECKING, Any, BinaryIO + +from clickhouse_connect.datatypes.base import ClickHouseType +from clickhouse_connect.driver.chdbclient import ChdbClient +from clickhouse_connect.driver.client import Client +from clickhouse_connect.driver.common import StreamContext +from clickhouse_connect.driver.external import ExternalData +from clickhouse_connect.driver.insert import InsertContext +from clickhouse_connect.driver.query import QueryContext, QueryResult, TzMode +from clickhouse_connect.driver.summary import QuerySummary + +if TYPE_CHECKING: + import numpy + import pandas + import polars + import pyarrow + +logger = logging.getLogger(__name__) + + +class AsyncChdbClient(Client): + """ + Async-facing client for the in-process chdb backend. Each public coroutine + schedules the corresponding sync ChdbClient call on the default thread + executor. 
Sync-only methods (settings, min_version) are passed through + directly. + """ + + valid_transport_settings: set[str] = ChdbClient.valid_transport_settings + optional_transport_settings: set[str] = ChdbClient.optional_transport_settings + + def __init__(self, sync: ChdbClient): + self._sync = sync + # Mirror attributes commonly read off the client object so user code that + # touches them (server_version, server_tz, database, etc.) keeps working. + self.server_tz = sync.server_tz + self.server_version = sync.server_version + self.server_settings = sync.server_settings + self.database = sync.database + self.uri = sync.uri + self.query_limit = sync.query_limit + self.query_retries = sync.query_retries + self.tz_mode = sync.tz_mode + self._tz_source = sync._tz_source + self._apply_server_tz = sync._apply_server_tz + self._dst_safe = sync._dst_safe + self.show_clickhouse_errors = sync.show_clickhouse_errors + self.protocol_version = sync.protocol_version + self.write_compression = sync.write_compression + self.compression = sync.compression + self._read_format = sync._read_format + self._write_format = sync._write_format + self._transform = sync._transform + + @property + def chdb_connection(self): + return self._sync.chdb_connection + + async def _run(self, func, *args, **kwargs): + loop = asyncio.get_running_loop() + if kwargs: + return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) + return await loop.run_in_executor(None, func, *args) + + # ---- sync passthroughs (no I/O) ---- + + def set_client_setting(self, key: str, value: Any) -> None: + self._sync.set_client_setting(key, value) + + def get_client_setting(self, key: str) -> str | None: + return self._sync.get_client_setting(key) + + def set_access_token(self, access_token: str) -> None: + self._sync.set_access_token(access_token) + + def min_version(self, version_str: str) -> bool: + return self._sync.min_version(version_str) + + # ---- async overrides ---- + + async def 
_query_with_context(self, context: QueryContext) -> QueryResult: # type: ignore[override] + return await self._run(self._sync._query_with_context, context) + + async def query( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + column_oriented: bool | None = None, + use_numpy: bool | None = None, + max_str_len: int | None = None, + context: QueryContext | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> QueryResult: + return await self._run( + lambda: self._sync.query( + query=query, + parameters=parameters, + settings=settings, + query_formats=query_formats, + column_formats=column_formats, + encoding=encoding, + use_none=use_none, + column_oriented=column_oriented, + use_numpy=use_numpy, + max_str_len=max_str_len, + context=context, + query_tz=query_tz, + column_tzs=column_tzs, + external_data=external_data, + transport_settings=transport_settings, + tz_mode=tz_mode, + ) + ) + + async def query_column_block_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_column_block_stream(*args, **kwargs)) + + async def query_row_block_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_row_block_stream(*args, **kwargs)) + + async def query_rows_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_rows_stream(*args, **kwargs)) + + async def query_np(self, *args, **kwargs) -> numpy.ndarray: + 
return await self._run(lambda: self._sync.query_np(*args, **kwargs)) + + async def query_np_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_np_stream(*args, **kwargs)) + + async def query_df(self, *args, **kwargs) -> pandas.DataFrame: + return await self._run(lambda: self._sync.query_df(*args, **kwargs)) + + async def query_df_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_df_stream(*args, **kwargs)) + + async def query_arrow(self, *args, **kwargs) -> pyarrow.Table: + return await self._run(lambda: self._sync.query_arrow(*args, **kwargs)) + + async def query_arrow_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_arrow_stream(*args, **kwargs)) + + async def query_df_arrow(self, *args, **kwargs) -> pandas.DataFrame | polars.DataFrame: + return await self._run(lambda: self._sync.query_df_arrow(*args, **kwargs)) + + async def query_df_arrow_stream(self, *args, **kwargs) -> StreamContext: # type: ignore[override] + return await self._run(lambda: self._sync.query_df_arrow_stream(*args, **kwargs)) + + async def command( # type: ignore[override] + self, + cmd: str, + parameters: Sequence | dict[str, Any] | None = None, + data: str | bytes | None = None, + settings: dict[str, Any] | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> str | int | Sequence[str] | QuerySummary: + return await self._run( + lambda: self._sync.command( + cmd, + parameters=parameters, + data=data, + settings=settings, + use_database=use_database, + external_data=external_data, + transport_settings=transport_settings, + ) + ) + + async def ping(self) -> bool: # type: ignore[override] + return await self._run(self._sync.ping) + + async def raw_query( # type: ignore[override] + self, + 
query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> bytes: + return await self._run( + lambda: self._sync.raw_query( + query, + parameters=parameters, + settings=settings, + fmt=fmt, + use_database=use_database, + external_data=external_data, + transport_settings=transport_settings, + ) + ) + + async def raw_stream( # type: ignore[override] + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> io.IOBase: + return await self._run( + lambda: self._sync.raw_stream( + query, + parameters=parameters, + settings=settings, + fmt=fmt, + use_database=use_database, + external_data=external_data, + transport_settings=transport_settings, + ) + ) + + async def insert( # type: ignore[override] + self, + table: str | None = None, + data=None, + column_names: str | Iterable[str] = "*", + database: str | None = None, + column_types: Sequence[ClickHouseType] | None = None, + column_type_names: Sequence[str] | None = None, + column_oriented: bool = False, + settings: dict[str, Any] | None = None, + context: InsertContext | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: + return await self._run( + lambda: self._sync.insert( + table=table, + data=data, + column_names=column_names, + database=database, + column_types=column_types, + column_type_names=column_type_names, + column_oriented=column_oriented, + settings=settings, + context=context, + transport_settings=transport_settings, + ) + ) + + async def insert_df(self, *args, **kwargs) -> QuerySummary: # type: ignore[override] + return await self._run(lambda: 
self._sync.insert_df(*args, **kwargs)) + + async def insert_arrow(self, *args, **kwargs) -> QuerySummary: # type: ignore[override] + return await self._run(lambda: self._sync.insert_arrow(*args, **kwargs)) + + async def insert_df_arrow(self, *args, **kwargs) -> QuerySummary: # type: ignore[override] + return await self._run(lambda: self._sync.insert_df_arrow(*args, **kwargs)) + + async def data_insert(self, context: InsertContext) -> QuerySummary: # type: ignore[override] + return await self._run(self._sync.data_insert, context) + + async def raw_insert( # type: ignore[override] + self, + table: str | None = None, + column_names: Sequence[str] | None = None, + insert_block: str | bytes | Generator[bytes, None, None] | BinaryIO | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + compression: str | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: + return await self._run( + lambda: self._sync.raw_insert( + table=table, + column_names=column_names, + insert_block=insert_block, + settings=settings, + fmt=fmt, + compression=compression, + transport_settings=transport_settings, + ) + ) + + async def close(self) -> None: # type: ignore[override] + await self._run(self._sync.close) + + async def close_connections(self) -> None: # type: ignore[override] + await self._run(self._sync.close_connections) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() + return False + + # Some helper methods on Client (like create_insert_context, create_query_context) + # do synchronous local work and call self.query/self.command for schema lookup. We + # can't await inside a sync method, so users should normally rely on insert/query + # which we already async-wrap. 
+ + def create_insert_context(self, *args, **kwargs): + return self._sync.create_insert_context(*args, **kwargs) + + def create_query_context(self, *args, **kwargs): + return self._sync.create_query_context(*args, **kwargs) diff --git a/clickhouse_connect/driver/chdbclient.py b/clickhouse_connect/driver/chdbclient.py new file mode 100644 index 00000000..9e84c25c --- /dev/null +++ b/clickhouse_connect/driver/chdbclient.py @@ -0,0 +1,695 @@ +""" +In-process chdb backend for clickhouse-connect. + +ChdbClient implements the Client contract on top of the embedded ClickHouse engine +exposed by the `chdb` Python package. The same Native byte format that the HTTP +server emits is consumed verbatim, so all of clickhouse-connect's existing type, +dtype, and result conversion machinery is reused. +""" + +from __future__ import annotations + +import io +import logging +import os +import sys +import tempfile +import threading +import uuid +from collections.abc import Generator, Sequence +from typing import TYPE_CHECKING, Any, BinaryIO + +from clickhouse_connect import common +from clickhouse_connect.driver import options +from clickhouse_connect.driver.binding import bind_query, quote_identifier +from clickhouse_connect.driver.client import Client +from clickhouse_connect.driver.common import coerce_int +from clickhouse_connect.driver.ctypes import RespBuffCls +from clickhouse_connect.driver.exceptions import ( + DatabaseError, + NotSupportedError, + ProgrammingError, +) +from clickhouse_connect.driver.external import ExternalData +from clickhouse_connect.driver.insert import InsertContext +from clickhouse_connect.driver.query import QueryContext, QueryResult, TzMode, TzSource +from clickhouse_connect.driver.summary import QuerySummary +from clickhouse_connect.driver.transform import NativeTransform + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# HTTP-only kwargs accepted (and ignored) so users can switch interface without +# editing the rest of their 
connection config. +_HTTP_ONLY_KWARGS = frozenset( + { + "compress", + "compression", + "connect_timeout", + "send_receive_timeout", + "client_name", + "verify", + "ca_cert", + "client_cert", + "client_cert_key", + "session_id", + "pool_mgr", + "http_proxy", + "https_proxy", + "tls_mode", + "proxy_path", + "form_encode_query_params", + "rename_response_column", + "autogenerate_session_id", + "autogenerate_query_id", + "connector_limit", + "connector_limit_per_host", + "keepalive_timeout", + "server_host_name", + } +) + + +class _BytesSource: + """ + Minimal stand-in for the HTTP `ResponseSource` that the response buffer + expects. Yields a single chunk of bytes and exposes the attributes the + transform layer reads. + """ + + __slots__ = ("data", "last_message", "exception_tag") + + def __init__(self, data: bytes): + self.data = data + self.last_message = None + self.exception_tag = None + + @property + def gen(self): + def _gen(): + yield self.data + + return _gen() + + def close(self): + return None + + +class _ChunkIterSource: + """Source backed by an iterator of byte chunks, used for streaming reads.""" + + __slots__ = ("_chunks", "last_message", "exception_tag") + + def __init__(self, chunks): + self._chunks = iter(chunks) + self.last_message = None + self.exception_tag = None + + @property + def gen(self): + return self._chunks + + def close(self): + try: + close = getattr(self._chunks, "close", None) + if close: + close() + except Exception: # noqa: BLE001 + pass + + +# Module globals used to expose user-provided Python objects (DataFrames, PyArrow +# tables) to chdb's `Python(name)` table function. chdb walks frames and module +# globals looking for the bare name passed to `Python(...)`, so we register +# objects under a uuid-suffixed name and clean up afterwards. 
_chdb_ref_lock = threading.Lock()


def _register_chdb_object(obj) -> str:
    """
    Publish `obj` in this module's globals under a fresh uuid-suffixed name.

    :param obj: object to expose (stored as-is, not copied)
    :return: the generated global name, to be passed to `_unregister_chdb_object`
    """
    name = f"_chdb_ref_{uuid.uuid4().hex}"
    with _chdb_ref_lock:
        globals()[name] = obj
    return name


def _unregister_chdb_object(name: str) -> None:
    """Remove a name added by `_register_chdb_object`; unknown names are ignored."""
    with _chdb_ref_lock:
        globals().pop(name, None)


def _format_error_message(message: str) -> str:
    """Extract a clean ClickHouse exception message from a chdb error string."""
    if not message:
        return ""
    # Drop whatever precedes the first "Code: " marker; when the marker is
    # absent (or already at position 0) the whole message is kept.
    start = message.find("Code: ")
    return (message[start:] if start > 0 else message).strip()


def _build_conn_string(chdb_path: str, chdb_options: dict[str, Any] | None) -> str:
    """
    Build the string handed to `chdb.connect`.

    :param chdb_path: "", ":memory:" or "memory" for an in-memory engine, a plain
        path, or an already-formed `file:` URI
    :param chdb_options: optional mapping rendered as URL query parameters
        (values are converted with `str`)
    :return: the bare path when there are no options, otherwise a `file:` URI
        carrying the options in its query string

    Options are also applied when `chdb_path` is already a `file:` URI: they are
    appended to its query string rather than being silently dropped.
    """
    from urllib.parse import urlencode

    query = urlencode({k: str(v) for k, v in chdb_options.items()}) if chdb_options else ""

    if chdb_path and chdb_path.startswith("file:"):
        if not query:
            return chdb_path
        if chdb_path.endswith(("?", "&")):
            # URI already ends with a separator; don't emit "??" or "&&".
            separator = ""
        elif "?" in chdb_path:
            separator = "&"
        else:
            separator = "?"
        return f"{chdb_path}{separator}{query}"

    path = ":memory:" if not chdb_path or chdb_path in (":memory:", "memory") else chdb_path
    if not query:
        return path
    # For the in-memory case this yields "file::memory:?<query>".
    return f"file:{path}?{query}"


class ChdbClient(Client):
    """ClickHouse Connect client backed by the in-process chdb engine."""

    # HTTP-style transport settings: accepted by setting validation but stripped
    # before being forwarded to chdb (they have no in-process equivalent).
+ valid_transport_settings: set[str] = { + "client_protocol_version", + "session_id", + "session_timeout", + "session_check", + "query_id", + "quota_key", + "compress", + "decompress", + "wait_end_of_query", + "buffer_size", + "role", + } + optional_transport_settings: set[str] = { + "send_progress_in_http_headers", + "http_headers_progress_interval_ms", + "enable_http_compression", + } + + def __init__( + self, + chdb_path: str = ":memory:", + chdb_options: dict[str, Any] | None = None, + database: str = "__default__", + settings: dict[str, Any] | None = None, + query_limit: int = 0, + query_retries: int = 0, + tz_source: TzSource | None = None, + tz_mode: TzMode | None = None, + show_clickhouse_errors: bool | None = None, + **ignored, + ): + if sys.platform.startswith("win"): + raise NotSupportedError("chdb backend is not supported on Windows") + + try: + import chdb + except ImportError as ex: + raise ImportError("chdb backend requires the chdb package. Install with: pip install 'clickhouse-connect[chdb]'") from ex + + for key in ignored: + if key in _HTTP_ONLY_KWARGS: + continue + logger.warning("ChdbClient: ignoring unrecognized kwarg %r", key) + + self._chdb_path = chdb_path or ":memory:" + self._chdb_options = dict(chdb_options) if chdb_options else {} + self._connection_string = _build_conn_string(self._chdb_path, self._chdb_options) + self._chdb_module = chdb + self._conn = chdb.connect(self._connection_string) + self._lock = threading.Lock() + self._closed = False + self._client_settings: dict[str, str] = {} + self._initial_settings = dict(settings or {}) + self._read_format = "Native" + self._write_format = "Native" + self._transform = NativeTransform() + self._integration_libs: set[str] = set() + self.uri = f"chdb://{self._chdb_path}" + self.write_compression = None + self.compression = None + + # coerce_int handles None-or-string flexibility + super().__init__( + database=database, + uri=self.uri, + query_limit=coerce_int(query_limit), + 
query_retries=coerce_int(query_retries), + server_host_name=None, + tz_source=tz_source, + tz_mode=tz_mode, + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=True, + ) + + for k, v in self._initial_settings.items(): + self.set_client_setting(k, v) + + if self.database: + self._exec_raw_query(f"USE {quote_identifier(self.database)}") + + logger.info( + "ChdbClient connected: chdb=%s, server_version=%s, path=%s", + getattr(chdb, "__version__", "?"), + self.server_version, + self._chdb_path, + ) + + # ---- helpers ------------------------------------------------------- + + @property + def chdb_connection(self): + """Underlying chdb connection. Escape hatch for advanced users.""" + return self._conn + + def _ensure_open(self) -> None: + if self._closed: + raise ProgrammingError("ChdbClient is closed") from None + + def _filter_per_call_settings(self, settings: dict[str, Any] | None) -> dict[str, str]: + """Validate per-call settings and drop transport-only ones.""" + out: dict[str, str] = {} + if not settings: + return out + invalid_action = common.get_setting("invalid_setting_action") + for k, v in settings.items(): + str_v = self._validate_setting(k, v, invalid_action) + if str_v is None: + continue + if k in self.valid_transport_settings or k in self.optional_transport_settings: + continue + out[k] = str_v + return out + + def _append_settings_clause(self, sql: str, settings: dict[str, str]) -> str: + if not settings: + return sql + extras = ", ".join(f"{k} = {v}" for k, v in settings.items()) + if " SETTINGS " in sql.upper(): + return f"{sql}, {extras}" + return f"{sql} SETTINGS {extras}" + + def _persist_setting(self, key: str, value: str) -> None: + """Apply a setting to the underlying chdb session via SET.""" + try: + with self._lock: + self._conn.query(f"SET {key} = {value}", "TabSeparated") + except Exception as ex: # noqa: BLE001 + logger.debug("Failed to apply SET %s=%s to chdb session: %s", key, value, ex) + + def _exec_raw_query(self, sql: 
str, fmt: str = "Native") -> bytes: + """Run a query against chdb under the per-client lock and return raw bytes.""" + self._ensure_open() + with self._lock: + try: + result = self._conn.query(sql, fmt) + except Exception as ex: # noqa: BLE001 + raise self._wrap_exception(ex) from ex + return result.bytes() if hasattr(result, "bytes") else bytes(result) + + def _wrap_exception(self, ex: Exception) -> Exception: + message = _format_error_message(str(ex)) + if not self.show_clickhouse_errors: + message = "ClickHouse error" + return DatabaseError(message) + + def _format_for_command(self) -> str: + return "TabSeparated" + + # ---- abstract method implementations ------------------------------- + + def set_client_setting(self, key: str, value: Any) -> None: + str_value = self._validate_setting(key, value, common.get_setting("invalid_setting_action")) + if str_value is None: + return + self._client_settings[key] = str_value + if key in self.valid_transport_settings or key in self.optional_transport_settings: + return + self._persist_setting(key, str_value) + + def get_client_setting(self, key: str) -> str | None: + return self._client_settings.get(key) + + def set_access_token(self, access_token: str) -> None: + # chdb has no auth concept; accept silently for HTTP-mode drop-in compatibility. + return None + + def _query_with_context(self, context: QueryContext) -> QueryResult: + self._ensure_open() + if context.external_data is not None: + raise NotSupportedError("external_data is not supported by the chdb backend") + # chdb's Native output does not include the 8-byte block_info prefix that the + # HTTP server emits when client_protocol_version is set. + context.block_info = False + final_query = self._prep_query(context) + if isinstance(final_query, bytes): + final_query = final_query.decode() + if context.is_insert: + # INSERT ... VALUES carries its data inline and has no result block to parse; + # appending `FORMAT Native` to a VALUES statement is a syntax error. 
+ sql = self._append_settings_clause(final_query, self._filter_per_call_settings(context.settings)) + self._exec_raw_query(sql, "TabSeparated") + return QueryResult([]) + sql = f"{final_query}\n FORMAT Native" + sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) + data = self._exec_raw_query(sql, "Native") + byte_source = RespBuffCls(_BytesSource(data)) + query_result = self._transform.parse_response(byte_source, context) + query_result.summary = {} + return query_result + + def raw_query( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> bytes: + if external_data is not None: + raise NotSupportedError("external_data is not supported by the chdb backend") + final_query, _ = bind_query(query, parameters, self.server_tz) + if isinstance(final_query, bytes): + final_query = final_query.decode() + if fmt: + final_query = f"{final_query}\n FORMAT {fmt}" + final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) + return self._exec_raw_query(final_query, fmt or "Native") + + def raw_stream( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> io.IOBase: + if external_data is not None: + raise NotSupportedError("external_data is not supported by the chdb backend") + final_query, _ = bind_query(query, parameters, self.server_tz) + if isinstance(final_query, bytes): + final_query = final_query.decode() + if fmt: + final_query = f"{final_query}\n FORMAT {fmt}" + final_query = self._append_settings_clause(final_query, 
self._filter_per_call_settings(settings)) + self._ensure_open() + # Acquire the lock for the lifetime of the streaming read so concurrent + # callers don't interleave queries on the same chdb connection. + self._lock.acquire() + try: + streaming = self._conn.send_query(final_query, fmt or "Native") + except Exception as ex: # noqa: BLE001 + self._lock.release() + raise self._wrap_exception(ex) from ex + return _ChdbStreamFile(streaming, self._lock) + + def command( + self, + cmd: str, + parameters: Sequence | dict[str, Any] | None = None, + data: str | bytes | None = None, + settings: dict[str, Any] | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> str | int | Sequence[str] | QuerySummary: + if external_data is not None: + raise NotSupportedError("external_data is not supported by the chdb backend") + cmd, _ = bind_query(cmd, parameters, self.server_tz) + if isinstance(cmd, bytes): + cmd = cmd.decode() + if data is not None: + if isinstance(data, bytes): + data_str = data.decode() + else: + data_str = data + cmd = f"{cmd}\n{data_str}" + per_call = self._filter_per_call_settings(settings) + # ClickHouse DDL doesn't accept a SETTINGS clause; apply per-call settings to the + # chdb session via SET before running the command. Client-level settings are + # already applied at set time, so no extra work needed for them. + for k, v in per_call.items(): + self._persist_setting(k, v) + body = self._exec_raw_query(cmd, self._format_for_command()) + if not body: + return QuerySummary({}) + try: + text = body.decode() + except UnicodeDecodeError: + return str(body) + # Match HTTP client semantics: strip trailing newline, split by tab, single + # token tries to coerce to int. 
+ if text.endswith("\n"): + text = text[:-1] + result = text.split("\t") + if len(result) == 1: + try: + return int(result[0]) + except ValueError: + return result[0] + return result + + def ping(self) -> bool: + try: + self._exec_raw_query("SELECT 1", "TabSeparated") + return True + except Exception: # noqa: BLE001 + logger.debug("chdb ping failed", exc_info=True) + return False + + def data_insert(self, context: InsertContext) -> QuerySummary: + if context.empty: + return QuerySummary() + + # DataFrame fast path: hand the DataFrame to chdb directly via the + # `Python(name)` table function. This skips serialization and disk I/O. + if self._can_use_dataframe_fast_path(context): + df = context.data + return self._insert_dataframe_fast(context, df) + + return self._insert_via_infile(context) + + def raw_insert( + self, + table: str | None = None, + column_names: Sequence[str] | None = None, + insert_block: str | bytes | Generator[bytes, None, None] | BinaryIO | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + compression: str | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: + if insert_block is None or not table: + raise ProgrammingError("raw_insert requires a table and insert_block") + if compression: + raise NotSupportedError("compression is not supported for raw_insert in chdb mode. Provide uncompressed bytes.") + + fmt = fmt or self._write_format + cols = "" + if column_names: + cols = f" ({', '.join(quote_identifier(c) for c in column_names)})" + + # Drain insert_block to a temp file, then INSERT FROM INFILE. 
+ tmp = tempfile.NamedTemporaryFile(suffix=f".{fmt.lower()}", delete=False) + try: + try: + if isinstance(insert_block, (bytes, bytearray)): + tmp.write(bytes(insert_block)) + elif isinstance(insert_block, str): + tmp.write(insert_block.encode()) + elif hasattr(insert_block, "read"): + while True: + chunk = insert_block.read(1 << 20) + if not chunk: + break + tmp.write(chunk if isinstance(chunk, (bytes, bytearray)) else chunk.encode()) + else: + for chunk in insert_block: + tmp.write(chunk if isinstance(chunk, (bytes, bytearray)) else chunk.encode()) + finally: + tmp.close() + + sql = f"INSERT INTO {table}{cols} FROM INFILE '{tmp.name}' FORMAT {fmt}" + sql = self._append_settings_clause(sql, self._filter_per_call_settings(settings)) + self._exec_raw_query(sql, "TabSeparated") + return QuerySummary({}) + finally: + try: + os.unlink(tmp.name) + except OSError: + pass + + def close(self) -> None: + if self._closed: + return + try: + with self._lock: + self._conn.close() + except Exception: # noqa: BLE001 + logger.debug("Error closing chdb connection", exc_info=True) + self._closed = True + + def close_connections(self) -> None: + # chdb only has a single embedded connection per client. + self.close() + + # ---- insert implementations ---------------------------------------- + + def _can_use_dataframe_fast_path(self, context: InsertContext) -> bool: + if options.pd is None: + return False + data = context.data + if not isinstance(data, options.pd.DataFrame): + return False + return True + + def _insert_dataframe_fast(self, context: InsertContext, df) -> QuerySummary: + # Reorder/rename DataFrame columns to match the target schema so the + # `SELECT * FROM Python(df)` projection lines up with the destination. 
+ try: + chdb_df = df[list(context.column_names)] if list(df.columns) != list(context.column_names) else df + except KeyError as ex: + raise ProgrammingError(f"DataFrame is missing target column {ex}") from None + + ref_name = _register_chdb_object(chdb_df) + try: + sql = ( + f"INSERT INTO {context.table} ({', '.join(quote_identifier(c) for c in context.column_names)}) " + f"SELECT * FROM Python({ref_name})" + ) + sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) + self._exec_raw_query(sql, "TabSeparated") + finally: + _unregister_chdb_object(ref_name) + context.data = None + return QuerySummary({}) + + def _insert_via_infile(self, context: InsertContext) -> QuerySummary: + tmp = tempfile.NamedTemporaryFile(suffix=".native", delete=False) + try: + try: + first_chunk = True + # NativeTransform.build_insert prepends an `INSERT INTO ... FORMAT Native\n` + # statement to the first chunk for the HTTP request body. We're going to + # write only the Native bytes to a file and INSERT FROM INFILE, so the + # prefix must be skipped. 
+ for chunk in self._transform.build_insert(context): + if context.insert_exception is not None: + ex = context.insert_exception + context.insert_exception = None + raise ex + if first_chunk: + nl = chunk.find(b"\n") + if nl >= 0: + chunk = chunk[nl + 1 :] + first_chunk = False + tmp.write(chunk) + finally: + tmp.close() + + cols = ", ".join(quote_identifier(c) for c in context.column_names) + sql = f"INSERT INTO {context.table} ({cols}) FROM INFILE '{tmp.name}' FORMAT Native" + sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) + self._exec_raw_query(sql, "TabSeparated") + return QuerySummary({}) + finally: + try: + os.unlink(tmp.name) + except OSError: + pass + context.data = None + + # ---- integration tagging ------------------------------------------ + + def _add_integration_tag(self, name: str) -> None: + # No User-Agent header to update for in-process chdb; just record for + # potential future use. + self._integration_libs.add(name) + + +class _ChdbStreamFile(io.RawIOBase): + """ + File-like adapter wrapping chdb's StreamingResult iterator so callers in + clickhouse-connect (which expect an io.IOBase / aiohttp-style stream) can + iterate bytes block-by-block. + + Holds a per-client lock for its lifetime so the chdb connection is not used + concurrently by another caller while a stream is in flight. 
+ """ + + def __init__(self, streaming_result, lock: threading.Lock): + super().__init__() + self._sr = streaming_result + self._lock = lock + self._buf = b"" + self._eof = False + self._closed_flag = False + + def readable(self) -> bool: + return True + + def _pull(self) -> bytes: + while True: + try: + chunk = next(self._sr) + except StopIteration: + self._eof = True + return b"" + payload = chunk.bytes() if hasattr(chunk, "bytes") else bytes(chunk) + if payload: + return payload + + def read(self, size: int | None = -1) -> bytes: + if self._closed_flag: + return b"" + if size is None or size < 0: + parts = [self._buf] + self._buf = b"" + while not self._eof: + chunk = self._pull() + if not chunk: + break + parts.append(chunk) + return b"".join(parts) + while len(self._buf) < size and not self._eof: + chunk = self._pull() + if not chunk: + break + self._buf += chunk + if not self._buf: + return b"" + out = self._buf[:size] + self._buf = self._buf[size:] + return out + + def readinto(self, buf) -> int: + data = self.read(len(buf)) + n = len(data) + if n: + buf[:n] = data + return n + + def close(self) -> None: + if self._closed_flag: + return + self._closed_flag = True + try: + close = getattr(self._sr, "close", None) + if close: + close() + except Exception: # noqa: BLE001 + logger.debug("Error closing chdb StreamingResult", exc_info=True) + finally: + try: + self._lock.release() + except RuntimeError: + pass + super().close() diff --git a/setup.py b/setup.py index 856e0a19..8fafa0d1 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ def run_setup(try_c: bool = True): "tzlocal": ["tzlocal>=4.0"], "tzdata": ["tzdata"], "async": ["aiohttp>=3.8.0"], + "chdb": ['chdb>=4.1.7; sys_platform != "win32"'], }, tests_require=["pytest"], entry_points={ diff --git a/tests/test_bare_install.py b/tests/test_bare_install.py index 59a29c09..02fff136 100644 --- a/tests/test_bare_install.py +++ b/tests/test_bare_install.py @@ -1,8 +1,29 @@ +import importlib.util + import 
clickhouse_connect +def test_chdb_backend_missing_dep_raises_clean_error(): + """Without chdb installed, requesting interface='chdb' must surface a clean ImportError. + + The bare install CI job deliberately omits the chdb extra, so this verifies the friendly + error path. If chdb happens to be importable (local dev), this assertion is skipped. + """ + if importlib.util.find_spec("chdb") is not None: + print("chdb is installed; skipping missing-dep error path check") + return + try: + clickhouse_connect.get_client(interface="chdb") + except ImportError as ex: + assert "chdb" in str(ex), f"expected chdb in error message, got: {ex}" + return + raise AssertionError("Expected ImportError when chdb is not installed") + + def test_bare_install(): """Bare install test to validate the package works with only core dependencies""" + test_chdb_backend_missing_dep_raises_clean_error() + client = clickhouse_connect.get_client() ver = client.command("SELECT version()") @@ -17,3 +38,8 @@ def test_bare_install(): res = client.query("SELECT * FROM _bare_install_test ORDER BY id") assert res.result_rows == [(1, "a"), (2, "b")], f"unexpected: {res.result_rows}" client.command("DROP TABLE _bare_install_test") + + +if __name__ == "__main__": + test_chdb_backend_missing_dep_raises_clean_error() + test_bare_install() diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index a2050b4a..9ec0edee 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -24,4 +24,5 @@ lz4>=4.4.5; python_version >= "3.14" pyjwt[crypto]==2.10.1 pre-commit==4.3.0 ruff==0.15.8 +chdb>=4.1.7; sys_platform != "win32" \ No newline at end of file diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py new file mode 100644 index 00000000..f8188652 --- /dev/null +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -0,0 +1,320 @@ +""" +Unit tests for the in-process chdb client backend. 
+ +These tests do not require a ClickHouse server — chdb is the embedded engine. +Skipped automatically if `chdb` is not installable (e.g. Windows or bare +install). +""" + +from __future__ import annotations + +import asyncio +from datetime import date, datetime +from decimal import Decimal + +import pytest + +chdb = pytest.importorskip("chdb") + +import clickhouse_connect # noqa: E402 +from clickhouse_connect.driver.exceptions import ( # noqa: E402 + DatabaseError, + NotSupportedError, +) + + +@pytest.fixture +def client(): + c = clickhouse_connect.get_client(interface="chdb") + yield c + c.close() + + +@pytest.fixture +def async_client(): + return clickhouse_connect.get_async_client + + +# ---- basic protocol ---- + + +def test_ping(client): + assert client.ping() is True + + +def test_server_version_populated(client): + assert client.server_version + assert client.server_version.split(".")[0].isdigit() + + +def test_uri_shape(): + c = clickhouse_connect.get_client(interface="chdb", chdb_path=":memory:") + try: + assert c.uri.startswith("chdb://") + finally: + c.close() + + +def test_chdb_connection_escape_hatch_exposed(client): + assert client.chdb_connection is not None + + +# ---- query / command ---- + + +def test_command_returns_scalar(client): + assert client.command("SELECT 13") == 13 + assert client.command("SELECT 'user_1'") == "user_1" + + +def test_command_returns_tuple_for_multiple_columns(client): + result = client.command("SELECT 79, 'user_2'") + assert result == ["79", "user_2"] + + +def test_query_primitives(client): + r = client.query( + "SELECT toInt32(13) AS i, toString('user_1') AS s, toFloat64(3.14) AS f", + ) + assert r.column_names == ("i", "s", "f") + assert r.result_rows == [(13, "user_1", 3.14)] + + +def test_query_nullable_and_low_cardinality(client): + r = client.query("SELECT CAST(NULL AS Nullable(Int64)) AS n, CAST('user_2' AS LowCardinality(String)) AS lc") + row = r.result_rows[0] + assert row[0] is None + assert row[1] == 
"user_2" + + +def test_query_dates_decimals(client): + r = client.query("SELECT toDate('2026-05-19') AS d, toDateTime('2026-05-19 10:30:00', 'UTC') AS dt, toDecimal64(123.456, 3) AS dec") + d, dt, dec = r.result_rows[0] + assert d == date(2026, 5, 19) + assert dt == datetime(2026, 5, 19, 10, 30, 0) + assert dec == Decimal("123.456") + + +def test_query_array_and_map(client): + r = client.query("SELECT [1, 2, 3]::Array(UInt32) AS arr, map('user_1', 13, 'user_2', 79) AS m") + arr, m = r.result_rows[0] + assert list(arr) == [1, 2, 3] + assert m == {"user_1": 13, "user_2": 79} + + +def test_query_multi_row(client): + r = client.query("SELECT number FROM numbers(5)") + assert [row[0] for row in r.result_rows] == [0, 1, 2, 3, 4] + + +def test_query_empty(client): + r = client.query("SELECT 1 WHERE 0") + assert r.result_rows == [] + + +def test_raw_query_pass_through(client): + body = client.raw_query("SELECT 13 AS x", fmt="TabSeparated") + assert body == b"13\n" + + +# ---- insert paths ---- + + +def test_insert_row_data(client): + client.command("CREATE TABLE row_insert_test (id UInt32, name String) ENGINE = Memory") + client.insert( + "row_insert_test", + [[13, "user_1"], [79, "user_2"]], + column_names=["id", "name"], + ) + r = client.query("SELECT id, name FROM row_insert_test ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_insert_dataframe_fast_path(client): + pd = pytest.importorskip("pandas") + client.command("CREATE TABLE df_insert_test (id UInt32, v Float64) ENGINE = Memory") + df = pd.DataFrame({"id": [13, 79, 103], "v": [1.5, 2.5, 3.5]}) + client.insert_df("df_insert_test", df) + r = client.query("SELECT id, v FROM df_insert_test ORDER BY id") + assert r.result_rows == [(13, 1.5), (79, 2.5), (103, 3.5)] + + +def test_insert_dataframe_reordered_columns(client): + pd = pytest.importorskip("pandas") + client.command("CREATE TABLE df_reorder (id UInt32, v Float64) ENGINE = Memory") + df = pd.DataFrame({"v": [9.5, 10.5], "id": 
[13, 79]}) # reversed + client.insert_df("df_reorder", df) + r = client.query("SELECT id, v FROM df_reorder ORDER BY id") + assert r.result_rows == [(13, 9.5), (79, 10.5)] + + +def test_raw_insert_bytes_round_trip(client): + client.command("CREATE TABLE raw_insert_test (id UInt32, v String) ENGINE = Memory") + csv = b"13,user_1\n79,user_2\n" + client.raw_insert("raw_insert_test", insert_block=csv, fmt="CSV") + r = client.query("SELECT id, v FROM raw_insert_test ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +# ---- session semantics ---- + + +def test_session_persistence_within_client(client): + client.command("CREATE TEMPORARY TABLE temp_persist (id Int32)") + client.command("INSERT INTO temp_persist VALUES (13), (79)") + r = client.query("SELECT count() FROM temp_persist") + assert r.result_rows[0][0] == 2 + + +def test_set_client_setting_persists(client): + client.set_client_setting("max_block_size", 1000) + assert client.get_client_setting("max_block_size") == "1000" + + +# ---- streaming ---- + + +def test_query_row_block_stream(client): + with client.query_row_block_stream("SELECT number FROM numbers(50) SETTINGS max_block_size = 10") as stream: + blocks = list(stream) + assert sum(len(b) for b in blocks) == 50 + + +def test_raw_stream_iterates(client): + stream = client.raw_stream("SELECT number FROM numbers(5)", fmt="CSV") + try: + data = stream.read() + finally: + stream.close() + assert data.startswith(b"0\n") + + +# ---- error mapping ---- + + +def test_unknown_function_maps_to_database_error(client): + with pytest.raises(DatabaseError) as ex_info: + client.query("SELECT bad_function()") + assert "UNKNOWN_FUNCTION" in str(ex_info.value) or "bad_function" in str(ex_info.value) + + +def test_external_data_not_supported(client): + from clickhouse_connect.driver.external import ExternalData + + ext = ExternalData(file_name="x.csv", data=b"1\n2\n", fmt="CSV", structure="id UInt32") + with pytest.raises(NotSupportedError): + 
client.query("SELECT * FROM x", external_data=ext) + + +# ---- HTTP-only kwargs accepted silently ---- + + +def test_http_only_kwargs_silently_ignored(): + c = clickhouse_connect.get_client( + interface="chdb", + username="default", + password="ignored", + compress=True, + connect_timeout=10, + verify=True, + http_proxy="http://localhost:3128", + ) + try: + assert c.ping() is True + finally: + c.close() + + +def test_set_access_token_silent_noop(client): + client.set_access_token("not-a-real-token") # must not raise + + +# ---- DBAPI on top of chdb ---- + + +def test_dbapi_cursor_round_trip(): + import clickhouse_connect.dbapi as dbapi + + conn = dbapi.connect(interface="chdb") + try: + cur = conn.cursor() + try: + cur.execute("CREATE TABLE dba_round_trip (id UInt32, name String) ENGINE = Memory") + cur.execute("INSERT INTO dba_round_trip VALUES (13, 'user_1'), (79, 'user_2')") + cur.execute("SELECT id, name FROM dba_round_trip ORDER BY id") + rows = cur.fetchall() + assert rows == [(13, "user_1"), (79, "user_2")] + assert [c[0] for c in cur.description] == ["id", "name"] + finally: + cur.close() + finally: + conn.close() + + +# ---- async client ---- + + +def test_async_client_basic_flow(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + assert await c.ping() is True + r = await c.query("SELECT 13 AS x") + assert r.result_rows == [(13,)] + await c.command("CREATE TABLE async_smoke (id UInt32) ENGINE = Memory") + await c.insert("async_smoke", [[13], [79]], column_names=["id"]) + r = await c.query("SELECT count() FROM async_smoke") + assert r.result_rows[0][0] == 2 + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_client_gather_serializes_without_error(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + results = await asyncio.gather( + c.query("SELECT 13"), + c.query("SELECT 79"), + c.query("SELECT 103"), + ) + values = [r.result_rows[0][0] for r in 
results] + assert sorted(values) == [13, 79, 103] + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_dataframe_fast_path(): + pd = pytest.importorskip("pandas") + + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + await c.command("CREATE TABLE async_df (id UInt32, v Float64) ENGINE = Memory") + df = pd.DataFrame({"id": [13, 79], "v": [1.5, 2.5]}) + await c.insert_df("async_df", df) + out = await c.query_df("SELECT id, v FROM async_df ORDER BY id") + assert list(out["id"]) == [13, 79] + assert list(out["v"]) == [1.5, 2.5] + finally: + await c.close() + + asyncio.run(run()) + + +# ---- factory / dispatch ---- + + +def test_factory_dispatches_on_interface(): + c = clickhouse_connect.get_client(interface="chdb") + try: + from clickhouse_connect.driver.chdbclient import ChdbClient + + assert isinstance(c, ChdbClient) + finally: + c.close() From 0466ca5c788f89af2c5c36ae877f1540f399826e Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 02:16:37 +0000 Subject: [PATCH 2/8] Drop sys_platform marker on chdb extra so Windows fails at install time --- setup.py | 2 +- tests/test_requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8fafa0d1..9d5cd175 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ def run_setup(try_c: bool = True): "tzlocal": ["tzlocal>=4.0"], "tzdata": ["tzdata"], "async": ["aiohttp>=3.8.0"], - "chdb": ['chdb>=4.1.7; sys_platform != "win32"'], + "chdb": ["chdb>=4.1.7"], }, tests_require=["pytest"], entry_points={ diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index 9ec0edee..18490143 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -24,5 +24,5 @@ lz4>=4.4.5; python_version >= "3.14" pyjwt[crypto]==2.10.1 pre-commit==4.3.0 ruff==0.15.8 -chdb>=4.1.7; sys_platform != "win32" +chdb>=4.1.7 \ No newline at end of file From 1230dbffe5df9236a213c1ddc20cd715a28f1099 Mon Sep 
17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 08:21:21 +0000 Subject: [PATCH 3/8] Drop chdb DataFrame fast path and clean up dead code --- clickhouse_connect/driver/__init__.py | 4 +- clickhouse_connect/driver/chdbasync.py | 23 ++---- clickhouse_connect/driver/chdbclient.py | 80 ------------------- tests/test_bare_install.py | 8 +- .../unit_tests/test_driver/test_chdbclient.py | 4 +- 5 files changed, 10 insertions(+), 109 deletions(-) diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index cb42bcd4..626f4280 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -343,7 +343,7 @@ def _create_chdb_client( try: from clickhouse_connect.driver.chdbclient import ChdbClient except ImportError as ex: - if ex.name == "chdb" or (ex.name and ex.name.startswith("chdb")): + if ex.name and ex.name.startswith("chdb"): raise ImportError("chdb backend requires the chdb package. Install with: pip install 'clickhouse-connect[chdb]'") from ex raise @@ -370,7 +370,7 @@ def _create_chdb_async_client( try: from clickhouse_connect.driver.chdbasync import AsyncChdbClient except ImportError as ex: - if ex.name == "chdb" or (ex.name and ex.name.startswith("chdb")): + if ex.name and ex.name.startswith("chdb"): raise ImportError("chdb backend requires the chdb package. 
Install with: pip install 'clickhouse-connect[chdb]'") from ex raise diff --git a/clickhouse_connect/driver/chdbasync.py b/clickhouse_connect/driver/chdbasync.py index d4998cc5..8dc9d58f 100644 --- a/clickhouse_connect/driver/chdbasync.py +++ b/clickhouse_connect/driver/chdbasync.py @@ -12,7 +12,6 @@ import asyncio import io -import logging from collections.abc import Generator, Iterable, Sequence from datetime import tzinfo from typing import TYPE_CHECKING, Any, BinaryIO @@ -32,8 +31,6 @@ import polars import pyarrow -logger = logging.getLogger(__name__) - class AsyncChdbClient(Client): """ @@ -43,9 +40,6 @@ class AsyncChdbClient(Client): directly. """ - valid_transport_settings: set[str] = ChdbClient.valid_transport_settings - optional_transport_settings: set[str] = ChdbClient.optional_transport_settings - def __init__(self, sync: ChdbClient): self._sync = sync # Mirror attributes commonly read off the client object so user code that @@ -75,9 +69,7 @@ def chdb_connection(self): async def _run(self, func, *args, **kwargs): loop = asyncio.get_running_loop() - if kwargs: - return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) - return await loop.run_in_executor(None, func, *args) + return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) # ---- sync passthroughs (no I/O) ---- @@ -315,13 +307,8 @@ async def __aexit__(self, exc_type, exc, tb): await self.close() return False - # Some helper methods on Client (like create_insert_context, create_query_context) - # do synchronous local work and call self.query/self.command for schema lookup. We - # can't await inside a sync method, so users should normally rely on insert/query - # which we already async-wrap. 
- - def create_insert_context(self, *args, **kwargs): - return self._sync.create_insert_context(*args, **kwargs) + async def create_insert_context(self, *args, **kwargs) -> InsertContext: # type: ignore[override] + return await self._run(lambda: self._sync.create_insert_context(*args, **kwargs)) - def create_query_context(self, *args, **kwargs): - return self._sync.create_query_context(*args, **kwargs) + async def create_query_context(self, *args, **kwargs) -> QueryContext: # type: ignore[override] + return await self._run(lambda: self._sync.create_query_context(*args, **kwargs)) diff --git a/clickhouse_connect/driver/chdbclient.py b/clickhouse_connect/driver/chdbclient.py index 9e84c25c..d933a10f 100644 --- a/clickhouse_connect/driver/chdbclient.py +++ b/clickhouse_connect/driver/chdbclient.py @@ -15,12 +15,10 @@ import sys import tempfile import threading -import uuid from collections.abc import Generator, Sequence from typing import TYPE_CHECKING, Any, BinaryIO from clickhouse_connect import common -from clickhouse_connect.driver import options from clickhouse_connect.driver.binding import bind_query, quote_identifier from clickhouse_connect.driver.client import Client from clickhouse_connect.driver.common import coerce_int @@ -97,48 +95,6 @@ def close(self): return None -class _ChunkIterSource: - """Source backed by an iterator of byte chunks, used for streaming reads.""" - - __slots__ = ("_chunks", "last_message", "exception_tag") - - def __init__(self, chunks): - self._chunks = iter(chunks) - self.last_message = None - self.exception_tag = None - - @property - def gen(self): - return self._chunks - - def close(self): - try: - close = getattr(self._chunks, "close", None) - if close: - close() - except Exception: # noqa: BLE001 - pass - - -# Module globals used to expose user-provided Python objects (DataFrames, PyArrow -# tables) to chdb's `Python(name)` table function. 
chdb walks frames and module -# globals looking for the bare name passed to `Python(...)`, so we register -# objects under a uuid-suffixed name and clean up afterwards. -_chdb_ref_lock = threading.Lock() - - -def _register_chdb_object(obj) -> str: - name = f"_chdb_ref_{uuid.uuid4().hex}" - with _chdb_ref_lock: - globals()[name] = obj - return name - - -def _unregister_chdb_object(name: str) -> None: - with _chdb_ref_lock: - globals().pop(name, None) - - def _format_error_message(message: str) -> str: """Extract a clean ClickHouse exception message from a chdb error string.""" if not message: @@ -468,13 +424,6 @@ def ping(self) -> bool: def data_insert(self, context: InsertContext) -> QuerySummary: if context.empty: return QuerySummary() - - # DataFrame fast path: hand the DataFrame to chdb directly via the - # `Python(name)` table function. This skips serialization and disk I/O. - if self._can_use_dataframe_fast_path(context): - df = context.data - return self._insert_dataframe_fast(context, df) - return self._insert_via_infile(context) def raw_insert( @@ -543,35 +492,6 @@ def close_connections(self) -> None: # ---- insert implementations ---------------------------------------- - def _can_use_dataframe_fast_path(self, context: InsertContext) -> bool: - if options.pd is None: - return False - data = context.data - if not isinstance(data, options.pd.DataFrame): - return False - return True - - def _insert_dataframe_fast(self, context: InsertContext, df) -> QuerySummary: - # Reorder/rename DataFrame columns to match the target schema so the - # `SELECT * FROM Python(df)` projection lines up with the destination. 
- try: - chdb_df = df[list(context.column_names)] if list(df.columns) != list(context.column_names) else df - except KeyError as ex: - raise ProgrammingError(f"DataFrame is missing target column {ex}") from None - - ref_name = _register_chdb_object(chdb_df) - try: - sql = ( - f"INSERT INTO {context.table} ({', '.join(quote_identifier(c) for c in context.column_names)}) " - f"SELECT * FROM Python({ref_name})" - ) - sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) - self._exec_raw_query(sql, "TabSeparated") - finally: - _unregister_chdb_object(ref_name) - context.data = None - return QuerySummary({}) - def _insert_via_infile(self, context: InsertContext) -> QuerySummary: tmp = tempfile.NamedTemporaryFile(suffix=".native", delete=False) try: diff --git a/tests/test_bare_install.py b/tests/test_bare_install.py index 02fff136..4e3f6bf4 100644 --- a/tests/test_bare_install.py +++ b/tests/test_bare_install.py @@ -4,11 +4,7 @@ def test_chdb_backend_missing_dep_raises_clean_error(): - """Without chdb installed, requesting interface='chdb' must surface a clean ImportError. - - The bare install CI job deliberately omits the chdb extra, so this verifies the friendly - error path. If chdb happens to be importable (local dev), this assertion is skipped. 
- """ + """Without chdb installed, interface='chdb' must raise a clean ImportError.""" if importlib.util.find_spec("chdb") is not None: print("chdb is installed; skipping missing-dep error path check") return @@ -22,8 +18,6 @@ def test_chdb_backend_missing_dep_raises_clean_error(): def test_bare_install(): """Bare install test to validate the package works with only core dependencies""" - test_chdb_backend_missing_dep_raises_clean_error() - client = clickhouse_connect.get_client() ver = client.command("SELECT version()") diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py index f8188652..d99f5be2 100644 --- a/tests/unit_tests/test_driver/test_chdbclient.py +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -131,7 +131,7 @@ def test_insert_row_data(client): assert r.result_rows == [(13, "user_1"), (79, "user_2")] -def test_insert_dataframe_fast_path(client): +def test_insert_dataframe(client): pd = pytest.importorskip("pandas") client.command("CREATE TABLE df_insert_test (id UInt32, v Float64) ENGINE = Memory") df = pd.DataFrame({"id": [13, 79, 103], "v": [1.5, 2.5, 3.5]}) @@ -289,7 +289,7 @@ async def run(): asyncio.run(run()) -def test_async_dataframe_fast_path(): +def test_async_dataframe_insert(): pd = pytest.importorskip("pandas") async def run(): From e70649f8301998ba52cf9ac46cc466f40a507860 Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 09:47:12 +0000 Subject: [PATCH 4/8] Restore chdb session settings after command() per-call overrides --- clickhouse_connect/driver/chdbclient.py | 114 ++++++++---------- .../unit_tests/test_driver/test_chdbclient.py | 26 ++++ 2 files changed, 79 insertions(+), 61 deletions(-) diff --git a/clickhouse_connect/driver/chdbclient.py b/clickhouse_connect/driver/chdbclient.py index d933a10f..dba4c025 100644 --- a/clickhouse_connect/driver/chdbclient.py +++ b/clickhouse_connect/driver/chdbclient.py @@ -39,36 +39,6 @@ logger = 
logging.getLogger(__name__) -# HTTP-only kwargs accepted (and ignored) so users can switch interface without -# editing the rest of their connection config. -_HTTP_ONLY_KWARGS = frozenset( - { - "compress", - "compression", - "connect_timeout", - "send_receive_timeout", - "client_name", - "verify", - "ca_cert", - "client_cert", - "client_cert_key", - "session_id", - "pool_mgr", - "http_proxy", - "https_proxy", - "tls_mode", - "proxy_path", - "form_encode_query_params", - "rename_response_column", - "autogenerate_session_id", - "autogenerate_query_id", - "connector_limit", - "connector_limit_per_host", - "keepalive_timeout", - "server_host_name", - } -) - class _BytesSource: """ @@ -106,20 +76,14 @@ def _format_error_message(message: str) -> str: def _build_conn_string(chdb_path: str, chdb_options: dict[str, Any] | None) -> str: - if not chdb_path or chdb_path in (":memory:", "memory"): - path = ":memory:" - elif chdb_path.startswith("file:"): - return chdb_path - else: - path = chdb_path + path = chdb_path or ":memory:" if not chdb_options: return path from urllib.parse import urlencode query = urlencode({k: str(v) for k, v in chdb_options.items()}) - if path == ":memory:": - return f"file::memory:?{query}" - return f"file:{path}?{query}" + sep = "&" if "?" in path else "?" 
+ return f"{path}{sep}{query}" class ChdbClient(Client): @@ -139,8 +103,6 @@ class ChdbClient(Client): "wait_end_of_query", "buffer_size", "role", - } - optional_transport_settings: set[str] = { "send_progress_in_http_headers", "http_headers_progress_interval_ms", "enable_http_compression", @@ -150,10 +112,9 @@ def __init__( self, chdb_path: str = ":memory:", chdb_options: dict[str, Any] | None = None, - database: str = "__default__", + database: str | None = None, settings: dict[str, Any] | None = None, query_limit: int = 0, - query_retries: int = 0, tz_source: TzSource | None = None, tz_mode: TzMode | None = None, show_clickhouse_errors: bool | None = None, @@ -167,11 +128,6 @@ def __init__( except ImportError as ex: raise ImportError("chdb backend requires the chdb package. Install with: pip install 'clickhouse-connect[chdb]'") from ex - for key in ignored: - if key in _HTTP_ONLY_KWARGS: - continue - logger.warning("ChdbClient: ignoring unrecognized kwarg %r", key) - self._chdb_path = chdb_path or ":memory:" self._chdb_options = dict(chdb_options) if chdb_options else {} self._connection_string = _build_conn_string(self._chdb_path, self._chdb_options) @@ -194,7 +150,7 @@ def __init__( database=database, uri=self.uri, query_limit=coerce_int(query_limit), - query_retries=coerce_int(query_retries), + query_retries=0, server_host_name=None, tz_source=tz_source, tz_mode=tz_mode, @@ -236,7 +192,7 @@ def _filter_per_call_settings(self, settings: dict[str, Any] | None) -> dict[str str_v = self._validate_setting(k, v, invalid_action) if str_v is None: continue - if k in self.valid_transport_settings or k in self.optional_transport_settings: + if k in self.valid_transport_settings: continue out[k] = str_v return out @@ -257,6 +213,39 @@ def _persist_setting(self, key: str, value: str) -> None: except Exception as ex: # noqa: BLE001 logger.debug("Failed to apply SET %s=%s to chdb session: %s", key, value, ex) + def _snapshot_settings(self, keys: Sequence[str]) -> dict[str, 
tuple[str, bool]]: + """Read current value and 'changed' flag for each key from system.settings. + + Returns a dict: {name -> (value, was_explicitly_set)}. + """ + if not keys: + return {} + quoted = ", ".join(f"'{k}'" for k in keys) + body = self._exec_raw_query( + f"SELECT name, value, changed FROM system.settings WHERE name IN ({quoted})", + "TabSeparated", + ) + result: dict[str, tuple[str, bool]] = {} + if body: + for line in body.decode().rstrip("\n").split("\n"): + parts = line.split("\t") + if len(parts) == 3: + name, value, changed = parts + result[name] = (value, changed == "1") + return result + + def _restore_settings(self, snapshot: dict[str, tuple[str, bool]]) -> None: + """Restore settings to the state captured by `_snapshot_settings`.""" + for name, (value, was_changed) in snapshot.items(): + try: + if was_changed: + self._persist_setting(name, value) + else: + with self._lock: + self._conn.query(f"SET {name} = DEFAULT", "TabSeparated") + except Exception: # noqa: BLE001 + logger.debug("Failed to restore setting %s after command()", name, exc_info=True) + def _exec_raw_query(self, sql: str, fmt: str = "Native") -> bytes: """Run a query against chdb under the per-client lock and return raw bytes.""" self._ensure_open() @@ -283,7 +272,7 @@ def set_client_setting(self, key: str, value: Any) -> None: if str_value is None: return self._client_settings[key] = str_value - if key in self.valid_transport_settings or key in self.optional_transport_settings: + if key in self.valid_transport_settings: return self._persist_setting(key, str_value) @@ -333,8 +322,6 @@ def raw_query( final_query, _ = bind_query(query, parameters, self.server_tz) if isinstance(final_query, bytes): final_query = final_query.decode() - if fmt: - final_query = f"{final_query}\n FORMAT {fmt}" final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) return self._exec_raw_query(final_query, fmt or "Native") @@ -353,8 +340,6 @@ def raw_stream( 
final_query, _ = bind_query(query, parameters, self.server_tz) if isinstance(final_query, bytes): final_query = final_query.decode() - if fmt: - final_query = f"{final_query}\n FORMAT {fmt}" final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) self._ensure_open() # Acquire the lock for the lifetime of the streaming read so concurrent @@ -389,12 +374,19 @@ def command( data_str = data cmd = f"{cmd}\n{data_str}" per_call = self._filter_per_call_settings(settings) - # ClickHouse DDL doesn't accept a SETTINGS clause; apply per-call settings to the - # chdb session via SET before running the command. Client-level settings are - # already applied at set time, so no extra work needed for them. - for k, v in per_call.items(): - self._persist_setting(k, v) - body = self._exec_raw_query(cmd, self._format_for_command()) + # ClickHouse DDL doesn't accept a SETTINGS clause; apply per-call settings to + # the chdb session via SET before running the command, then restore them + # afterwards so they don't leak into the session. 
+ snapshot: dict[str, tuple[str, bool]] = {} + if per_call: + snapshot = self._snapshot_settings(list(per_call.keys())) + for k, v in per_call.items(): + self._persist_setting(k, v) + try: + body = self._exec_raw_query(cmd, self._format_for_command()) + finally: + if snapshot: + self._restore_settings(snapshot) if not body: return QuerySummary({}) try: diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py index d99f5be2..d45644e2 100644 --- a/tests/unit_tests/test_driver/test_chdbclient.py +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -172,6 +172,32 @@ def test_set_client_setting_persists(client): assert client.get_client_setting("max_block_size") == "1000" +def _read_session_setting(client, name: str) -> str: + body = client.raw_query(f"SELECT value FROM system.settings WHERE name = '{name}'", fmt="TabSeparated") + return body.decode().strip() + + +def test_command_per_call_setting_does_not_leak(client): + before = _read_session_setting(client, "max_block_size") + client.command("SELECT 1", settings={"max_block_size": 13}) + after = _read_session_setting(client, "max_block_size") + assert after == before, f"max_block_size leaked: before={before!r} after={after!r}" + + +def test_command_per_call_setting_restored_on_error(client): + before = _read_session_setting(client, "max_block_size") + with pytest.raises(DatabaseError): + client.command("SELECT bad_function()", settings={"max_block_size": 13}) + after = _read_session_setting(client, "max_block_size") + assert after == before, f"max_block_size leaked after error: before={before!r} after={after!r}" + + +def test_command_restores_previously_set_value(client): + client.set_client_setting("max_block_size", 7) + client.command("SELECT 1", settings={"max_block_size": 13}) + assert _read_session_setting(client, "max_block_size") == "7" + + # ---- streaming ---- From b20f76c1d2c6367e6673e3826d63bbe1ee70d01f Mon Sep 17 00:00:00 2001 From: wudidapaopao 
Date: Thu, 21 May 2026 09:50:48 +0000 Subject: [PATCH 5/8] Add CHANGELOG entry for chdb backend --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e72e349e..46cd2bcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## UNRELEASED +### New Features +- Add an in-process `chdb` backend. `clickhouse_connect.get_client(interface="chdb")` (and `get_async_client`) returns a client backed by the embedded ClickHouse engine via the `chdb` Python package, with no server required. Install with `pip install 'clickhouse-connect[chdb]'`. The existing `NativeTransform` byte parser is reused, so type handling, DB-API, and SQLAlchemy paths work unchanged. Linux and macOS only; HTTP-only kwargs (auth, TLS, proxy, retry) and `external_data` are no-ops or not supported. + ### Bug Fixes - Async client: `ca_cert="certifi"` shorthand now resolves to `certifi.where()`, matching the sync client. Previously the async path passed the literal string to `ssl_context.load_verify_locations`, producing `FileNotFoundError`. Closes [#742](https://github.com/ClickHouse/clickhouse-connect/issues/742) - Fix SQLAlchemy dialect rendering for `ILIKE` and `NOT ILIKE` expressions to use native ClickHouse syntax instead of the generic SQLAlchemy `lower(...) LIKE lower(...)` fallback. 
From 86ad2c9e880d0e2a1ad5897b15db0ffe7452c7ce Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 11:00:31 +0000 Subject: [PATCH 6/8] Wire chdb parameter binding and expand test coverage --- CHANGELOG.md | 2 +- clickhouse_connect/driver/__init__.py | 14 +- clickhouse_connect/driver/chdbclient.py | 33 +- setup.py | 2 +- tests/test_bare_install.py | 16 - tests/test_requirements.txt | 1 - .../unit_tests/test_driver/test_chdbclient.py | 386 +++++++++++++++++- 7 files changed, 404 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46cd2bcd..95f6c0ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ## UNRELEASED ### New Features -- Add an in-process `chdb` backend. `clickhouse_connect.get_client(interface="chdb")` (and `get_async_client`) returns a client backed by the embedded ClickHouse engine via the `chdb` Python package, with no server required. Install with `pip install 'clickhouse-connect[chdb]'`. The existing `NativeTransform` byte parser is reused, so type handling, DB-API, and SQLAlchemy paths work unchanged. Linux and macOS only; HTTP-only kwargs (auth, TLS, proxy, retry) and `external_data` are no-ops or not supported. +- Add an in-process `chdb` backend. `clickhouse_connect.get_client(interface="chdb")` (and `get_async_client`) returns a client backed by the embedded ClickHouse engine via the `chdb` Python package, with no server required. ### Bug Fixes - Async client: `ca_cert="certifi"` shorthand now resolves to `certifi.where()`, matching the sync client. Previously the async path passed the literal string to `ssl_context.load_verify_locations`, producing `FileNotFoundError`. 
Closes [#742](https://github.com/ClickHouse/clickhouse-connect/issues/742) diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index 626f4280..4e616911 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -340,12 +340,7 @@ def _create_chdb_client( generic_args: dict[str, Any] | None, kwargs: dict[str, Any], ) -> Client: - try: - from clickhouse_connect.driver.chdbclient import ChdbClient - except ImportError as ex: - if ex.name and ex.name.startswith("chdb"): - raise ImportError("chdb backend requires the chdb package. Install with: pip install 'clickhouse-connect[chdb]'") from ex - raise + from clickhouse_connect.driver.chdbclient import ChdbClient settings = dict(settings or {}) if generic_args: @@ -367,12 +362,7 @@ def _create_chdb_async_client( generic_args: dict[str, Any] | None, kwargs: dict[str, Any], ): - try: - from clickhouse_connect.driver.chdbasync import AsyncChdbClient - except ImportError as ex: - if ex.name and ex.name.startswith("chdb"): - raise ImportError("chdb backend requires the chdb package. Install with: pip install 'clickhouse-connect[chdb]'") from ex - raise + from clickhouse_connect.driver.chdbasync import AsyncChdbClient sync_client = _create_chdb_client(database=database, settings=settings, generic_args=generic_args, kwargs=kwargs) return AsyncChdbClient(sync_client) # type: ignore[arg-type] diff --git a/clickhouse_connect/driver/chdbclient.py b/clickhouse_connect/driver/chdbclient.py index dba4c025..fe9fe79b 100644 --- a/clickhouse_connect/driver/chdbclient.py +++ b/clickhouse_connect/driver/chdbclient.py @@ -123,10 +123,7 @@ def __init__( if sys.platform.startswith("win"): raise NotSupportedError("chdb backend is not supported on Windows") - try: - import chdb - except ImportError as ex: - raise ImportError("chdb backend requires the chdb package. 
Install with: pip install 'clickhouse-connect[chdb]'") from ex + import chdb self._chdb_path = chdb_path or ":memory:" self._chdb_options = dict(chdb_options) if chdb_options else {} @@ -246,12 +243,17 @@ def _restore_settings(self, snapshot: dict[str, tuple[str, bool]]) -> None: except Exception: # noqa: BLE001 logger.debug("Failed to restore setting %s after command()", name, exc_info=True) - def _exec_raw_query(self, sql: str, fmt: str = "Native") -> bytes: + @staticmethod + def _strip_param_prefix(bind_params: dict[str, Any]) -> dict[str, Any]: + """chdb's `params` kwarg expects bare names (`x`); bind_query produces `param_x`.""" + return {(k[6:] if k.startswith("param_") else k): v for k, v in bind_params.items()} if bind_params else {} + + def _exec_raw_query(self, sql: str, fmt: str = "Native", params: dict[str, Any] | None = None) -> bytes: """Run a query against chdb under the per-client lock and return raw bytes.""" self._ensure_open() with self._lock: try: - result = self._conn.query(sql, fmt) + result = self._conn.query(sql, fmt, params=params or {}) except Exception as ex: # noqa: BLE001 raise self._wrap_exception(ex) from ex return result.bytes() if hasattr(result, "bytes") else bytes(result) @@ -293,15 +295,16 @@ def _query_with_context(self, context: QueryContext) -> QueryResult: final_query = self._prep_query(context) if isinstance(final_query, bytes): final_query = final_query.decode() + params = self._strip_param_prefix(context.bind_params) if context.is_insert: # INSERT ... VALUES carries its data inline and has no result block to parse; # appending `FORMAT Native` to a VALUES statement is a syntax error. 
sql = self._append_settings_clause(final_query, self._filter_per_call_settings(context.settings)) - self._exec_raw_query(sql, "TabSeparated") + self._exec_raw_query(sql, "TabSeparated", params=params) return QueryResult([]) sql = f"{final_query}\n FORMAT Native" sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) - data = self._exec_raw_query(sql, "Native") + data = self._exec_raw_query(sql, "Native", params=params) byte_source = RespBuffCls(_BytesSource(data)) query_result = self._transform.parse_response(byte_source, context) query_result.summary = {} @@ -319,11 +322,11 @@ def raw_query( ) -> bytes: if external_data is not None: raise NotSupportedError("external_data is not supported by the chdb backend") - final_query, _ = bind_query(query, parameters, self.server_tz) + final_query, bound = bind_query(query, parameters, self.server_tz) if isinstance(final_query, bytes): final_query = final_query.decode() final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) - return self._exec_raw_query(final_query, fmt or "Native") + return self._exec_raw_query(final_query, fmt or "Native", params=self._strip_param_prefix(bound)) def raw_stream( self, @@ -337,16 +340,17 @@ def raw_stream( ) -> io.IOBase: if external_data is not None: raise NotSupportedError("external_data is not supported by the chdb backend") - final_query, _ = bind_query(query, parameters, self.server_tz) + final_query, bound = bind_query(query, parameters, self.server_tz) if isinstance(final_query, bytes): final_query = final_query.decode() final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) + params = self._strip_param_prefix(bound) self._ensure_open() # Acquire the lock for the lifetime of the streaming read so concurrent # callers don't interleave queries on the same chdb connection. 
self._lock.acquire() try: - streaming = self._conn.send_query(final_query, fmt or "Native") + streaming = self._conn.send_query(final_query, fmt or "Native", params=params or {}) except Exception as ex: # noqa: BLE001 self._lock.release() raise self._wrap_exception(ex) from ex @@ -364,9 +368,10 @@ def command( ) -> str | int | Sequence[str] | QuerySummary: if external_data is not None: raise NotSupportedError("external_data is not supported by the chdb backend") - cmd, _ = bind_query(cmd, parameters, self.server_tz) + cmd, bound = bind_query(cmd, parameters, self.server_tz) if isinstance(cmd, bytes): cmd = cmd.decode() + params = self._strip_param_prefix(bound) if data is not None: if isinstance(data, bytes): data_str = data.decode() @@ -383,7 +388,7 @@ def command( for k, v in per_call.items(): self._persist_setting(k, v) try: - body = self._exec_raw_query(cmd, self._format_for_command()) + body = self._exec_raw_query(cmd, self._format_for_command(), params=params) finally: if snapshot: self._restore_settings(snapshot) diff --git a/setup.py b/setup.py index 9d5cd175..d111a410 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ def run_setup(try_c: bool = True): install_requires=[ "certifi", "urllib3>=1.26", + 'chdb>=4.1.7; sys_platform != "win32"', 'tzdata; sys_platform == "win32"', 'zstandard; python_version<"3.14"', 'zstandard>=0.25.0; python_version>="3.14"', @@ -75,7 +76,6 @@ def run_setup(try_c: bool = True): "tzlocal": ["tzlocal>=4.0"], "tzdata": ["tzdata"], "async": ["aiohttp>=3.8.0"], - "chdb": ["chdb>=4.1.7"], }, tests_require=["pytest"], entry_points={ diff --git a/tests/test_bare_install.py b/tests/test_bare_install.py index 4e3f6bf4..16c037d7 100644 --- a/tests/test_bare_install.py +++ b/tests/test_bare_install.py @@ -1,21 +1,6 @@ -import importlib.util - import clickhouse_connect -def test_chdb_backend_missing_dep_raises_clean_error(): - """Without chdb installed, interface='chdb' must raise a clean ImportError.""" - if 
importlib.util.find_spec("chdb") is not None: - print("chdb is installed; skipping missing-dep error path check") - return - try: - clickhouse_connect.get_client(interface="chdb") - except ImportError as ex: - assert "chdb" in str(ex), f"expected chdb in error message, got: {ex}" - return - raise AssertionError("Expected ImportError when chdb is not installed") - - def test_bare_install(): """Bare install test to validate the package works with only core dependencies""" client = clickhouse_connect.get_client() @@ -35,5 +20,4 @@ def test_bare_install(): if __name__ == "__main__": - test_chdb_backend_missing_dep_raises_clean_error() test_bare_install() diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index 18490143..a2050b4a 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -24,5 +24,4 @@ lz4>=4.4.5; python_version >= "3.14" pyjwt[crypto]==2.10.1 pre-commit==4.3.0 ruff==0.15.8 -chdb>=4.1.7 \ No newline at end of file diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py index d45644e2..8a96a186 100644 --- a/tests/unit_tests/test_driver/test_chdbclient.py +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -9,17 +9,22 @@ from __future__ import annotations import asyncio +import io +import os from datetime import date, datetime from decimal import Decimal +from uuid import UUID import pytest chdb = pytest.importorskip("chdb") import clickhouse_connect # noqa: E402 +from clickhouse_connect.driver.chdbclient import _build_conn_string, _format_error_message # noqa: E402 from clickhouse_connect.driver.exceptions import ( # noqa: E402 DatabaseError, NotSupportedError, + ProgrammingError, ) @@ -30,11 +35,6 @@ def client(): c.close() -@pytest.fixture -def async_client(): - return clickhouse_connect.get_async_client - - # ---- basic protocol ---- @@ -344,3 +344,379 @@ def test_factory_dispatches_on_interface(): assert isinstance(c, ChdbClient) finally: c.close() + + 
+# ---- pure helper unit tests (no chdb instance needed) ---- + + +def test_build_conn_string_default_memory(): + assert _build_conn_string("", None) == ":memory:" + assert _build_conn_string(None, None) == ":memory:" # type: ignore[arg-type] + + +def test_build_conn_string_path_unchanged_without_options(): + assert _build_conn_string("/data/db", None) == "/data/db" + assert _build_conn_string("file:/data/db?mode=ro", None) == "file:/data/db?mode=ro" + + +def test_build_conn_string_appends_options(): + assert _build_conn_string("/data/db", {"mode": "ro"}) == "/data/db?mode=ro" + + +def test_build_conn_string_merges_with_existing_query(): + result = _build_conn_string("file:/data/db?already=set", {"max_threads": 4}) + assert "already=set" in result and "max_threads=4" in result and "&" in result + + +def test_format_error_message_extracts_code_prefix(): + raw = "Some prefix\nCode: 46. DB::Exception: Function with name `bad` does not exist." + assert _format_error_message(raw).startswith("Code: 46.") + + +def test_format_error_message_passes_through_plain_text(): + assert _format_error_message("plain error") == "plain error" + assert _format_error_message("") == "" + + +# ---- closed client and lifecycle ---- + + +def test_query_after_close_raises(): + c = clickhouse_connect.get_client(interface="chdb") + c.close() + with pytest.raises(ProgrammingError): + c.query("SELECT 1") + + +def test_close_is_idempotent(): + c = clickhouse_connect.get_client(interface="chdb") + c.close() + c.close() # must not raise + + +def test_close_connections_closes_client(): + c = clickhouse_connect.get_client(interface="chdb") + c.close_connections() + with pytest.raises(ProgrammingError): + c.query("SELECT 1") + + +def test_context_manager_closes_client(): + with clickhouse_connect.get_client(interface="chdb") as c: + assert c.ping() is True + with pytest.raises(ProgrammingError): + c.query("SELECT 1") + + +# ---- chdb_path persistence ---- + + +def 
test_chdb_path_persists_across_clients(tmp_path): + db_path = str(tmp_path / "persisted.db") + + a = clickhouse_connect.get_client(interface="chdb", chdb_path=db_path) + try: + a.command("CREATE TABLE persisted (id UInt32) ENGINE = MergeTree ORDER BY id") + a.insert("persisted", [[13], [79]], column_names=["id"]) + finally: + a.close() + + b = clickhouse_connect.get_client(interface="chdb", chdb_path=db_path) + try: + rows = b.query("SELECT id FROM persisted ORDER BY id").result_rows + assert rows == [(13,), (79,)] + finally: + b.close() + + +# ---- per-call settings on query / insert ---- + + +def test_per_call_settings_appended_to_select(client): + # `max_block_size` changes only how rows are batched, not which rows are returned, + # so this checks that a query with a per-call setting still runs and returns the + # expected rows; it does not by itself prove that the setting reached chdb. + r = client.query("SELECT number FROM numbers(10)", settings={"max_block_size": 3}) + assert [row[0] for row in r.result_rows] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + +def test_per_call_settings_do_not_leak_via_query(client): + before = _read_session_setting(client, "max_block_size") + client.query("SELECT 1", settings={"max_block_size": 17}) + after = _read_session_setting(client, "max_block_size") + # query path uses inline SETTINGS clause (not SET), so it should never modify + # the session value at all.
+ assert after == before + + +# ---- show_clickhouse_errors ---- + + +def test_show_clickhouse_errors_false_sanitizes_message(): + c = clickhouse_connect.get_client(interface="chdb", show_clickhouse_errors=False) + try: + with pytest.raises(DatabaseError) as ex_info: + c.query("SELECT bad_function()") + assert "UNKNOWN_FUNCTION" not in str(ex_info.value) + assert "bad_function" not in str(ex_info.value) + finally: + c.close() + + +# ---- query_limit ---- + + +def test_query_limit_auto_appends_limit(): + c = clickhouse_connect.get_client(interface="chdb", query_limit=3) + try: + rows = c.query("SELECT number FROM numbers(100)").result_rows + assert len(rows) == 3 + finally: + c.close() + + +def test_explicit_limit_not_overridden_by_query_limit(): + c = clickhouse_connect.get_client(interface="chdb", query_limit=3) + try: + rows = c.query("SELECT number FROM numbers(100) LIMIT 7").result_rows + assert len(rows) == 7 + finally: + c.close() + + +# ---- streaming variations ---- + + +def test_raw_stream_via_context_manager(client): + with client.raw_stream("SELECT number FROM numbers(5)", fmt="CSV") as stream: + data = stream.read() + assert data == b"0\n1\n2\n3\n4\n" + + +def test_raw_stream_chunked_read(client): + stream = client.raw_stream("SELECT number FROM numbers(50)", fmt="CSV") + try: + out = b"" + while chunk := stream.read(8): + out += chunk + finally: + stream.close() + assert out == b"".join(f"{n}\n".encode() for n in range(50)) + + +def test_raw_stream_readinto(client): + stream = client.raw_stream("SELECT number FROM numbers(3)", fmt="CSV") + try: + buf = bytearray(64) + n = stream.readinto(buf) + assert buf[:n] == b"0\n1\n2\n" + finally: + stream.close() + + +def test_stream_release_lock_on_close(client): + # If close() doesn't release the lock, the next query would deadlock. 
+ stream = client.raw_stream("SELECT 1", fmt="CSV") + stream.close() + # Should return immediately, no deadlock: + assert client.query("SELECT 1").result_rows == [(1,)] + + +# ---- raw_insert input shapes ---- + + +def test_raw_insert_accepts_str(client): + client.command("CREATE TABLE raw_str (id UInt32, v String) ENGINE = Memory") + client.raw_insert("raw_str", insert_block="13,user_1\n79,user_2\n", fmt="CSV") + r = client.query("SELECT id, v FROM raw_str ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_accepts_file_like(client): + client.command("CREATE TABLE raw_file (id UInt32, v String) ENGINE = Memory") + buf = io.BytesIO(b"13,user_1\n79,user_2\n") + client.raw_insert("raw_file", insert_block=buf, fmt="CSV") + r = client.query("SELECT id, v FROM raw_file ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_accepts_generator(client): + client.command("CREATE TABLE raw_gen (id UInt32, v String) ENGINE = Memory") + + def chunks(): + yield b"13,user_1\n" + yield b"79,user_2\n" + + client.raw_insert("raw_gen", insert_block=chunks(), fmt="CSV") + r = client.query("SELECT id, v FROM raw_gen ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_compression_rejected(client): + client.command("CREATE TABLE raw_compress (id UInt32) ENGINE = Memory") + with pytest.raises(NotSupportedError): + client.raw_insert("raw_compress", insert_block=b"1\n", fmt="CSV", compression="lz4") + + +def test_raw_insert_missing_args(client): + with pytest.raises(ProgrammingError): + client.raw_insert(None, insert_block=b"x") # type: ignore[arg-type] + with pytest.raises(ProgrammingError): + client.raw_insert("t", insert_block=None) + + +def test_raw_insert_cleans_up_temp_file(client, monkeypatch): + """Verify the temp file is deleted even when chdb errors.""" + client.command("CREATE TABLE raw_cleanup (id UInt32) ENGINE = Memory") + seen_paths = [] + + 
import tempfile as _tempfile + + original = _tempfile.NamedTemporaryFile + + def tracking(*args, **kwargs): + f = original(*args, **kwargs) + seen_paths.append(f.name) + return f + + monkeypatch.setattr(_tempfile, "NamedTemporaryFile", tracking) + + # Bad CSV content for an UInt32 column will cause chdb to error. + with pytest.raises(DatabaseError): + client.raw_insert("raw_cleanup", insert_block=b"not_a_number\n", fmt="CSV") + + assert seen_paths, "temp file path not captured" + for p in seen_paths: + assert not os.path.exists(p), f"temp file leaked: {p}" + + +# ---- additional types ---- + + +def test_query_tuple_and_fixed_string(client): + r = client.query("SELECT tuple(1, 'a', 3.14) AS t, toFixedString('xyz', 4) AS fs") + t, fs = r.result_rows[0] + assert t == (1, "a", 3.14) + assert fs == b"xyz\x00" + + +def test_query_uuid(client): + val = "550e8400-e29b-41d4-a716-446655440000" + r = client.query(f"SELECT toUUID('{val}') AS u") + assert r.result_rows == [(UUID(val),)] + + +def test_query_ipv4_ipv6(client): + r = client.query("SELECT toIPv4('127.0.0.1') AS v4, toIPv6('::1') AS v6") + v4, v6 = r.result_rows[0] + import ipaddress + + assert v4 == ipaddress.IPv4Address("127.0.0.1") + assert v6 == ipaddress.IPv6Address("::1") + + +def test_query_enum(client): + r = client.query("SELECT CAST('a' AS Enum8('a' = 1, 'b' = 2)) AS e") + assert r.result_rows == [("a",)] + + +def test_query_datetime64_with_tz(client): + r = client.query("SELECT toDateTime64('2026-05-19 10:30:00.123456', 6, 'America/New_York') AS dt") + (dt,) = r.result_rows[0] + assert dt.year == 2026 and dt.microsecond == 123456 + + +def test_query_nan_handling(client): + r = client.query("SELECT CAST('nan' AS Float64) AS x, CAST('-inf' AS Float64) AS y") + x, y = r.result_rows[0] + assert x != x # NaN + assert y == float("-inf") + + +# ---- parameter binding ---- + + +def test_query_with_parameters(client): + r = client.query("SELECT {x:Int32} AS x, {name:String} AS name", parameters={"x": 13, "name": 
"user_1"}) + assert r.result_rows == [(13, "user_1")] + + +# ---- transport-only settings don't get persisted ---- + + +def test_transport_only_setting_not_persisted_to_session(client): + # session_id is a transport-only key; ChdbClient should accept it but NOT emit + # SET session_id=... to chdb (which would either error or apply a meaningless setting). + before = _read_session_setting(client, "session_id") + client.set_client_setting("session_id", "abc-123") + after = _read_session_setting(client, "session_id") + assert after == before + # But the recorded client-side value is kept for inspection + assert client.get_client_setting("session_id") == "abc-123" + + +# ---- DataFrame stream ---- + + +def test_query_df_stream(client): + pytest.importorskip("pandas") + client.command("CREATE TABLE df_stream (id UInt32) ENGINE = Memory") + client.insert("df_stream", [[i] for i in range(20)], column_names=["id"]) + with client.query_df_stream("SELECT id FROM df_stream SETTINGS max_block_size = 5") as stream: + frames = list(stream) + total = sum(len(f) for f in frames) + assert total == 20 + + +# ---- async additional coverage ---- + + +def test_async_external_data_rejected(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + from clickhouse_connect.driver.external import ExternalData + + ext = ExternalData(file_name="x.csv", data=b"1\n", fmt="CSV", structure="id UInt32") + with pytest.raises(NotSupportedError): + await c.query("SELECT * FROM x", external_data=ext) + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_query_error_propagates_as_database_error(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + with pytest.raises(DatabaseError): + await c.query("SELECT bad_function()") + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_closed_client_query_raises(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") 
+ await c.close() + with pytest.raises(ProgrammingError): + await c.query("SELECT 1") + + asyncio.run(run()) + + +def test_async_set_client_setting_is_sync(client): + # Async client's set_client_setting is intentionally sync (no I/O wrap) for + # symmetry with HTTP AsyncClient. + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + c.set_client_setting("max_block_size", 99) # NOT awaited + assert c.get_client_setting("max_block_size") == "99" + finally: + await c.close() + + asyncio.run(run()) From 889a30742393ee88d2fad341aedd61eedb9149e6 Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 12:27:07 +0000 Subject: [PATCH 7/8] Add chdb tests for arrow/numpy/streaming/DBAPI parity with HTTP --- clickhouse_connect/driver/chdbclient.py | 5 +- tests/test_bare_install.py | 4 - .../unit_tests/test_driver/test_chdbclient.py | 154 ++++++++++++++++++ 3 files changed, 158 insertions(+), 5 deletions(-) diff --git a/clickhouse_connect/driver/chdbclient.py b/clickhouse_connect/driver/chdbclient.py index fe9fe79b..b789a909 100644 --- a/clickhouse_connect/driver/chdbclient.py +++ b/clickhouse_connect/driver/chdbclient.py @@ -447,10 +447,13 @@ def raw_insert( tmp = tempfile.NamedTemporaryFile(suffix=f".{fmt.lower()}", delete=False) try: try: - if isinstance(insert_block, (bytes, bytearray)): + if isinstance(insert_block, (bytes, bytearray, memoryview)): tmp.write(bytes(insert_block)) elif isinstance(insert_block, str): tmp.write(insert_block.encode()) + elif hasattr(insert_block, "to_pybytes"): + # pyarrow.Buffer and friends — buffer protocol holder + tmp.write(insert_block.to_pybytes()) elif hasattr(insert_block, "read"): while True: chunk = insert_block.read(1 << 20) diff --git a/tests/test_bare_install.py b/tests/test_bare_install.py index 16c037d7..59a29c09 100644 --- a/tests/test_bare_install.py +++ b/tests/test_bare_install.py @@ -17,7 +17,3 @@ def test_bare_install(): res = client.query("SELECT * FROM _bare_install_test 
ORDER BY id") assert res.result_rows == [(1, "a"), (2, "b")], f"unexpected: {res.result_rows}" client.command("DROP TABLE _bare_install_test") - - -if __name__ == "__main__": - test_bare_install() diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py index 8a96a186..abc0c7c4 100644 --- a/tests/unit_tests/test_driver/test_chdbclient.py +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -256,6 +256,140 @@ def test_set_access_token_silent_noop(client): client.set_access_token("not-a-real-token") # must not raise +# ---- pyarrow / numpy round-trips ---- + + +def test_query_arrow(client): + pytest.importorskip("pyarrow") + client.command("CREATE TABLE arrow_q (id UInt32, name String) ENGINE = Memory") + client.insert("arrow_q", [[13, "user_1"], [79, "user_2"]], column_names=["id", "name"]) + table = client.query_arrow("SELECT id, name FROM arrow_q ORDER BY id") + assert table.column_names == ["id", "name"] + assert table.column("id").to_pylist() == [13, 79] + assert table.column("name").to_pylist() == ["user_1", "user_2"] + + +def test_query_arrow_stream(client): + pytest.importorskip("pyarrow") + client.command("CREATE TABLE arrow_qs (id UInt32) ENGINE = Memory") + client.insert("arrow_qs", [[i] for i in range(20)], column_names=["id"]) + with client.query_arrow_stream("SELECT id FROM arrow_qs SETTINGS max_block_size = 5") as stream: + batches = list(stream) + assert sum(b.num_rows for b in batches) == 20 + + +def test_insert_arrow_round_trip(client): + pa = pytest.importorskip("pyarrow") + client.command("CREATE TABLE arrow_ins (id UInt32, name String) ENGINE = Memory") + table = pa.table({"id": pa.array([13, 79], type=pa.uint32()), "name": pa.array(["user_1", "user_2"])}) + client.insert_arrow("arrow_ins", table) + r = client.query("SELECT id, name FROM arrow_ins ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_query_np(client): + pytest.importorskip("numpy") + 
client.command("CREATE TABLE np_q (id UInt32, v Float64) ENGINE = Memory") + client.insert("np_q", [[13, 1.5], [79, 2.5]], column_names=["id", "v"]) + arr = client.query_np("SELECT id, v FROM np_q ORDER BY id") + assert list(arr["id"]) == [13, 79] + assert list(arr["v"]) == [1.5, 2.5] + + +def test_query_np_stream(client): + pytest.importorskip("numpy") + client.command("CREATE TABLE np_qs (id UInt32) ENGINE = Memory") + client.insert("np_qs", [[i] for i in range(20)], column_names=["id"]) + with client.query_np_stream("SELECT id FROM np_qs SETTINGS max_block_size = 7") as stream: + chunks = list(stream) + assert sum(len(c) for c in chunks) == 20 + + +# ---- additional streaming flavors ---- + + +def test_query_column_block_stream(client): + client.command("CREATE TABLE col_stream (id UInt32, v String) ENGINE = Memory") + client.insert("col_stream", [[i, f"row_{i}"] for i in range(15)], column_names=["id", "v"]) + with client.query_column_block_stream("SELECT id, v FROM col_stream SETTINGS max_block_size = 5") as stream: + blocks = list(stream) + # Each block is column-oriented: a tuple of columns + total_rows = sum(len(block[0]) for block in blocks) + assert total_rows == 15 + + +def test_query_rows_stream(client): + client.command("CREATE TABLE rows_stream (id UInt32) ENGINE = Memory") + client.insert("rows_stream", [[i] for i in range(10)], column_names=["id"]) + with client.query_rows_stream("SELECT id FROM rows_stream ORDER BY id") as stream: + rows = list(stream) + assert [r[0] for r in rows] == list(range(10)) + + +# ---- insert variations ---- + + +def test_insert_column_oriented(client): + client.command("CREATE TABLE col_oriented (id UInt32, v Float64) ENGINE = Memory") + columns = [[13, 79, 103], [1.5, 2.5, 3.5]] + client.insert("col_oriented", columns, column_names=["id", "v"], column_oriented=True) + r = client.query("SELECT id, v FROM col_oriented ORDER BY id") + assert r.result_rows == [(13, 1.5), (79, 2.5), (103, 3.5)] + + +def 
test_reusable_insert_context(client): + client.command("CREATE TABLE reuse_ctx (id UInt32, name String) ENGINE = Memory") + ctx = client.create_insert_context("reuse_ctx", column_names=["id", "name"]) + client.insert(data=[[13, "first"]], context=ctx) + client.insert(data=[[79, "second"]], context=ctx) + r = client.query("SELECT id, name FROM reuse_ctx ORDER BY id") + assert r.result_rows == [(13, "first"), (79, "second")] + + +# ---- database parameter ---- + + +def test_database_parameter_switches_default(): + c = clickhouse_connect.get_client(interface="chdb") + try: + c.command("CREATE DATABASE other_db") + c.command("CREATE TABLE other_db.scoped (id UInt32) ENGINE = Memory") + c.command("INSERT INTO other_db.scoped VALUES (13)") + finally: + c.close() + # Note: chdb :memory: is per-connection, so this test only checks the USE + # mechanism — can't cross sessions on :memory:. Instead verify USE works inline: + c2 = clickhouse_connect.get_client(interface="chdb") + try: + c2.command("CREATE DATABASE scoped_test") + c2.command("USE scoped_test") + c2.command("CREATE TABLE local_t (id UInt32) ENGINE = Memory") + # unqualified reference should resolve into scoped_test + c2.command("INSERT INTO local_t VALUES (13)") + assert c2.query("SELECT count() FROM local_t").result_rows[0][0] == 1 + assert c2.query("SELECT count() FROM scoped_test.local_t").result_rows[0][0] == 1 + finally: + c2.close() + + +def test_database_param_forwarded_to_use(tmp_path): + db = str(tmp_path / "dbparam.db") + # First connection creates DB + table + a = clickhouse_connect.get_client(interface="chdb", chdb_path=db) + try: + a.command("CREATE DATABASE analytics") + a.command("CREATE TABLE analytics.events (id UInt32) ENGINE = MergeTree ORDER BY id") + a.command("INSERT INTO analytics.events VALUES (13)") + finally: + a.close() + # Second connection uses the database= kwarg; unqualified table reference must work + b = clickhouse_connect.get_client(interface="chdb", chdb_path=db, 
database="analytics") + try: + assert b.query("SELECT count() FROM events").result_rows[0][0] == 1 + finally: + b.close() + + # ---- DBAPI on top of chdb ---- @@ -278,6 +412,26 @@ def test_dbapi_cursor_round_trip(): conn.close() +def test_dbapi_executemany(): + import clickhouse_connect.dbapi as dbapi + + conn = dbapi.connect(interface="chdb") + try: + cur = conn.cursor() + try: + cur.execute("CREATE TABLE dba_many (id UInt32, name String) ENGINE = Memory") + cur.executemany( + "INSERT INTO dba_many (id, name) VALUES", + [{"id": 13, "name": "user_1"}, {"id": 79, "name": "user_2"}, {"id": 103, "name": "user_3"}], + ) + cur.execute("SELECT id, name FROM dba_many ORDER BY id") + assert cur.fetchall() == [(13, "user_1"), (79, "user_2"), (103, "user_3")] + finally: + cur.close() + finally: + conn.close() + + # ---- async client ---- From bdb483c8b92dd109e2efc9d70ea27a07eed7952a Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 21 May 2026 15:10:26 +0000 Subject: [PATCH 8/8] Fix chdb backend bugs and behavior gaps found via HTTP integration test parity --- clickhouse_connect/driver/chdbasync.py | 4 +- clickhouse_connect/driver/chdbclient.py | 235 ++++++++++++++++-- .../unit_tests/test_driver/test_chdbclient.py | 161 +++++++++++- 3 files changed, 380 insertions(+), 20 deletions(-) diff --git a/clickhouse_connect/driver/chdbasync.py b/clickhouse_connect/driver/chdbasync.py index 8dc9d58f..ee945adb 100644 --- a/clickhouse_connect/driver/chdbasync.py +++ b/clickhouse_connect/driver/chdbasync.py @@ -310,5 +310,5 @@ async def __aexit__(self, exc_type, exc, tb): async def create_insert_context(self, *args, **kwargs) -> InsertContext: # type: ignore[override] return await self._run(lambda: self._sync.create_insert_context(*args, **kwargs)) - async def create_query_context(self, *args, **kwargs) -> QueryContext: # type: ignore[override] - return await self._run(lambda: self._sync.create_query_context(*args, **kwargs)) + def create_query_context(self, *args, **kwargs) -> 
QueryContext: + return self._sync.create_query_context(*args, **kwargs) diff --git a/clickhouse_connect/driver/chdbclient.py b/clickhouse_connect/driver/chdbclient.py index b789a909..40cbd92d 100644 --- a/clickhouse_connect/driver/chdbclient.py +++ b/clickhouse_connect/driver/chdbclient.py @@ -10,8 +10,10 @@ from __future__ import annotations import io +import json import logging import os +import re import sys import tempfile import threading @@ -19,6 +21,7 @@ from typing import TYPE_CHECKING, Any, BinaryIO from clickhouse_connect import common +from clickhouse_connect.datatypes.registry import get_from_name from clickhouse_connect.driver.binding import bind_query, quote_identifier from clickhouse_connect.driver.client import Client from clickhouse_connect.driver.common import coerce_int @@ -27,6 +30,7 @@ DatabaseError, NotSupportedError, ProgrammingError, + StreamFailureError, ) from clickhouse_connect.driver.external import ExternalData from clickhouse_connect.driver.insert import InsertContext @@ -39,6 +43,26 @@ logger = logging.getLogger(__name__) +_columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) + +# chdb's `send_query` emits each ClickHouse block as a self-contained encoding in the +# requested format. For formats that have row-level (or block-level) self-description +# and no global header/footer/file structure, concatenating chunks yields a valid +# stream the caller's parser can consume directly. Other formats (Arrow, Parquet, +# JSON, *WithNames variants, ...) would emit duplicated headers / multiple file +# markers per chunk, which is not a valid larger stream. For those we fall back to a +# single non-streaming query so the result is one well-formed payload. 
+_STREAM_SAFE_FORMATS = frozenset( + { + "Native", + "TabSeparated", + "TSV", + "CSV", + "RowBinary", + "JSONEachRow", + } +) + class _BytesSource: """ @@ -65,6 +89,58 @@ def close(self): return None +class _ChdbStreamSource: + """ + Source for `ResponseBuffer` backed by a chdb `StreamingResult`. Yields each + block's bytes and translates chdb's mid-stream RuntimeError into the + `StreamFailureError` clickhouse-connect callers expect. + """ + + __slots__ = ("_sr", "_lock", "_released", "last_message", "exception_tag") + + def __init__(self, streaming_result, lock: threading.Lock): + self._sr = streaming_result + self._lock = lock + self._released = False + self.last_message = None + self.exception_tag = None + + @property + def gen(self): + def _gen(): + try: + while True: + try: + chunk = next(self._sr) + except StopIteration: + return + except Exception as ex: # noqa: BLE001 + raise StreamFailureError(_format_error_message(str(ex))) from ex + payload = chunk.bytes() if hasattr(chunk, "bytes") else bytes(chunk) + if payload: + yield payload + finally: + self.close() + + return _gen() + + def close(self): + if self._released: + return + self._released = True + try: + close = getattr(self._sr, "close", None) + if close: + close() + except Exception: # noqa: BLE001 + logger.debug("Error closing chdb StreamingResult", exc_info=True) + finally: + try: + self._lock.release() + except RuntimeError: + pass + + def _format_error_message(message: str) -> str: """Extract a clean ClickHouse exception message from a chdb error string.""" if not message: @@ -75,6 +151,48 @@ def _format_error_message(message: str) -> str: return message.strip() +def _drain_to_bytes(block) -> bytes: + """Collect any supported insert_block shape into a single bytes value.""" + if isinstance(block, (bytes, bytearray, memoryview)): + return bytes(block) + if isinstance(block, str): + return block.encode() + if hasattr(block, "to_pybytes"): + return block.to_pybytes() + if hasattr(block, "read"): + 
return block.read() + parts = [] + for chunk in block: + parts.append(chunk if isinstance(chunk, (bytes, bytearray)) else chunk.encode()) + return b"".join(parts) + + +def _decompress(data: bytes, encoding: str) -> bytes: + if encoding == "lz4": + import lz4.frame + + return lz4.frame.decompress(data) + if encoding == "zstd": + import zstandard + + return zstandard.ZstdDecompressor().decompress(data) + if encoding == "gzip": + import gzip + + return gzip.decompress(data) + if encoding == "br": + try: + import brotli + except ImportError as ex: + raise NotSupportedError("brotli is required to decompress 'br' for chdb raw_insert") from ex + return brotli.decompress(data) + if encoding == "deflate": + import zlib + + return zlib.decompress(data) + raise NotSupportedError(f"Unsupported compression {encoding!r} for chdb raw_insert") + + def _build_conn_string(chdb_path: str, chdb_options: dict[str, Any] | None) -> str: path = chdb_path or ":memory:" if not chdb_options: @@ -92,6 +210,7 @@ class ChdbClient(Client): # HTTP-style transport settings: accepted by setting validation but stripped # before being forwarded to chdb (they have no in-process equivalent). valid_transport_settings: set[str] = { + "database", "client_protocol_version", "session_id", "session_timeout", @@ -194,10 +313,29 @@ def _filter_per_call_settings(self, settings: dict[str, Any] | None) -> dict[str out[k] = str_v return out - def _append_settings_clause(self, sql: str, settings: dict[str, str]) -> str: + @staticmethod + def _quote_setting_value(value: str) -> str: + """SQL-quote a setting value so chdb sees the expected literal type. + + Without quotes chdb parses bare numeric-looking strings as UInt64; if the + setting is actually String-typed (e.g. `insert_deduplication_token`) this + triggers `Bad get: has UInt64, requested String`. ClickHouse coerces + single-quoted literals back to numeric types where needed, so quoting + unconditionally is safe. 
+ """ + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + + def _append_settings_clause(self, sql, settings): if not settings: return sql - extras = ", ".join(f"{k} = {v}" for k, v in settings.items()) + extras = ", ".join(f"{k} = {self._quote_setting_value(v)}" for k, v in settings.items()) + if isinstance(sql, bytes): + # raw_query can receive a bytes SQL when binary parameter substitution + # produced non-UTF-8 byte sequences. chdb accepts bytes natively, so + # keep the bytes path and append the settings clause as bytes too. + sep = b", " if b" SETTINGS " in sql.upper() else b" SETTINGS " + return sql + sep + extras.encode() if " SETTINGS " in sql.upper(): return f"{sql}, {extras}" return f"{sql} SETTINGS {extras}" @@ -206,7 +344,7 @@ def _persist_setting(self, key: str, value: str) -> None: """Apply a setting to the underlying chdb session via SET.""" try: with self._lock: - self._conn.query(f"SET {key} = {value}", "TabSeparated") + self._conn.query(f"SET {key} = {self._quote_setting_value(value)}", "TabSeparated") except Exception as ex: # noqa: BLE001 logger.debug("Failed to apply SET %s=%s to chdb session: %s", key, value, ex) @@ -296,6 +434,11 @@ def _query_with_context(self, context: QueryContext) -> QueryResult: if isinstance(final_query, bytes): final_query = final_query.decode() params = self._strip_param_prefix(context.bind_params) + if not context.is_insert and _columns_only_re.search(context.uncommented_query): + # chdb emits zero Native bytes for a LIMIT 0 query, so the Native parser + # would return an empty result with no column metadata. Fetch the schema + # via JSON instead, matching the HTTP client's columns-only fast path. + return self._fetch_columns_only(context, final_query, params) if context.is_insert: # INSERT ... VALUES carries its data inline and has no result block to parse; # appending `FORMAT Native` to a VALUES statement is a syntax error. 
@@ -304,12 +447,47 @@ def _query_with_context(self, context: QueryContext) -> QueryResult: return QueryResult([]) sql = f"{final_query}\n FORMAT Native" sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) - data = self._exec_raw_query(sql, "Native", params=params) - byte_source = RespBuffCls(_BytesSource(data)) + if context.streaming: + # Use chdb's streaming `send_query` so mid-execution engine errors + # (e.g. throwIf, division by zero on row N) surface during result + # iteration as `StreamFailureError`, matching HTTP's contract. The + # non-streaming `conn.query` would raise eagerly and lose lazy-error + # semantics — we only opt into that for true streaming results, since + # holding the per-client lock for the lifetime of a non-iterated + # QueryResult would deadlock subsequent calls. + self._ensure_open() + self._lock.acquire() + try: + streaming = self._conn.send_query(sql, "Native", params=params or {}) + except Exception as ex: # noqa: BLE001 + self._lock.release() + raise self._wrap_exception(ex) from ex + byte_source = RespBuffCls(_ChdbStreamSource(streaming, self._lock)) + else: + data = self._exec_raw_query(sql, "Native", params=params) + byte_source = RespBuffCls(_BytesSource(data)) query_result = self._transform.parse_response(byte_source, context) query_result.summary = {} return query_result + def _fetch_columns_only(self, context: QueryContext, final_query: str, params: dict[str, Any]) -> QueryResult: + sql = self._append_settings_clause(f"{final_query}\n FORMAT JSON", self._filter_per_call_settings(context.settings)) + body = self._exec_raw_query(sql, "JSON", params=params) + meta = json.loads(body)["meta"] + renamer = context.column_renamer + names: list[str] = [] + types = [] + for col in meta: + name = col["name"] + if renamer is not None: + try: + name = renamer(name) + except Exception as ex: # noqa: BLE001 + logger.debug("Failed to rename column %s: %s", name, ex) + names.append(name) + 
types.append(get_from_name(col["type"])) + return QueryResult([], None, tuple(names), tuple(types)) + def raw_query( self, query: str, @@ -323,10 +501,11 @@ def raw_query( if external_data is not None: raise NotSupportedError("external_data is not supported by the chdb backend") final_query, bound = bind_query(query, parameters, self.server_tz) - if isinstance(final_query, bytes): - final_query = final_query.decode() + # chdb's conn.query accepts both str and bytes; preserve bytes when binary + # parameter substitution (e.g. `$xx$` placeholders) yields non-UTF-8 SQL. final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) - return self._exec_raw_query(final_query, fmt or "Native", params=self._strip_param_prefix(bound)) + # HTTP path defaults to server's TabSeparated when no fmt is provided. + return self._exec_raw_query(final_query, fmt or "TabSeparated", params=self._strip_param_prefix(bound)) def raw_stream( self, @@ -345,12 +524,19 @@ def raw_stream( final_query = final_query.decode() final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) params = self._strip_param_prefix(bound) + output_fmt = fmt or "TabSeparated" + if output_fmt not in _STREAM_SAFE_FORMATS: + # Formats with global structure (Arrow IPC, Parquet, JSON, *WithNames, ...) + # can't be assembled from chdb's per-block chunks. Fetch as a single + # well-formed payload and wrap as an in-memory stream. + data = self._exec_raw_query(final_query, output_fmt, params=params) + return io.BytesIO(data) self._ensure_open() # Acquire the lock for the lifetime of the streaming read so concurrent # callers don't interleave queries on the same chdb connection. 
self._lock.acquire() try: - streaming = self._conn.send_query(final_query, fmt or "Native", params=params or {}) + streaming = self._conn.send_query(final_query, output_fmt, params=params or {}) except Exception as ex: # noqa: BLE001 self._lock.release() raise self._wrap_exception(ex) from ex @@ -435,8 +621,14 @@ def raw_insert( ) -> QuerySummary: if insert_block is None or not table: raise ProgrammingError("raw_insert requires a table and insert_block") - if compression: - raise NotSupportedError("compression is not supported for raw_insert in chdb mode. Provide uncompressed bytes.") + if compression and compression != "identity": + # HTTP carries this via Content-Encoding so the server decompresses. + # chdb has no equivalent input stage, so the caller's pre-compressed + # bytes must be drained and decompressed in the client before being + # written to the INFILE temp file. + insert_block = _drain_to_bytes(insert_block) + insert_block = _decompress(insert_block, compression) + compression = None fmt = fmt or self._write_format cols = "" @@ -466,8 +658,11 @@ def raw_insert( finally: tmp.close() - sql = f"INSERT INTO {table}{cols} FROM INFILE '{tmp.name}' FORMAT {fmt}" - sql = self._append_settings_clause(sql, self._filter_per_call_settings(settings)) + per_call = self._filter_per_call_settings(settings) + settings_clause = ( + f" SETTINGS {', '.join(f'{k} = {self._quote_setting_value(v)}' for k, v in per_call.items())}" if per_call else "" + ) + sql = f"INSERT INTO {table}{cols} FROM INFILE '{tmp.name}'{settings_clause} FORMAT {fmt}" self._exec_raw_query(sql, "TabSeparated") return QuerySummary({}) finally: @@ -516,8 +711,11 @@ def _insert_via_infile(self, context: InsertContext) -> QuerySummary: tmp.close() cols = ", ".join(quote_identifier(c) for c in context.column_names) - sql = f"INSERT INTO {context.table} ({cols}) FROM INFILE '{tmp.name}' FORMAT Native" - sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings)) + per_call = 
self._filter_per_call_settings(context.settings) + settings_clause = ( + f" SETTINGS {', '.join(f'{k} = {self._quote_setting_value(v)}' for k, v in per_call.items())}" if per_call else "" + ) + sql = f"INSERT INTO {context.table} ({cols}) FROM INFILE '{tmp.name}'{settings_clause} FORMAT Native" self._exec_raw_query(sql, "TabSeparated") return QuerySummary({}) finally: @@ -563,6 +761,13 @@ def _pull(self) -> bytes: except StopIteration: self._eof = True return b"" + except Exception as ex: # noqa: BLE001 + # chdb wraps mid-stream engine errors as RuntimeError. Surface them + # as StreamFailureError so callers can catch them with the same + # exception type used by the HTTP backend's mid-stream failures. + msg = _format_error_message(str(ex)) + self._eof = True + raise StreamFailureError(msg) from ex payload = chunk.bytes() if hasattr(chunk, "bytes") else bytes(chunk) if payload: return payload diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py index abc0c7c4..b6dfd3cf 100644 --- a/tests/unit_tests/test_driver/test_chdbclient.py +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -216,6 +216,110 @@ def test_raw_stream_iterates(client): assert data.startswith(b"0\n") +# ---- raw_stream format dispatch ---- +# +# chdb's send_query emits each ClickHouse block as a self-contained payload, so only +# formats with no global header / footer / file marker can be concatenated chunk-by- +# chunk. For everything else raw_stream falls back to a non-streaming query that +# returns one well-formed payload. These tests pin both branches. 
+ + +def _stream_full_bytes(client, sql, fmt): + stream = client.raw_stream(sql, fmt=fmt) + try: + return stream.read() + finally: + stream.close() + + +def _row_count(client, sql, fmt): + """Run as raw_query (single payload) and return total bytes for comparison.""" + return client.raw_query(sql, fmt=fmt) + + +# All values verified end-to-end: 200k rows is enough to force chdb to emit multiple +# blocks (max_block_size default is ~65k). +_LARGE_QUERY = "SELECT number AS id FROM numbers(200000)" + + +@pytest.mark.parametrize("fmt", ["Native", "TabSeparated", "CSV", "RowBinary", "JSONEachRow"]) +def test_raw_stream_safe_format_full_data(client, fmt): + """Stream-safe formats: concatenated chunks must equal the single-query payload.""" + streamed = _stream_full_bytes(client, _LARGE_QUERY, fmt) + full = _row_count(client, _LARGE_QUERY, fmt) + assert len(streamed) == len(full), f"{fmt}: streamed {len(streamed)} != full {len(full)}" + + +@pytest.mark.parametrize( + "fmt", + [ + "Arrow", + "ArrowStream", + "Parquet", + "TabSeparatedWithNames", + "CSVWithNames", + "RowBinaryWithNamesAndTypes", + ], +) +def test_raw_stream_unsafe_format_falls_back_to_single_payload(client, fmt): + """Unsafe formats fall back to non-streaming: result must equal single-query bytes.""" + streamed = _stream_full_bytes(client, _LARGE_QUERY, fmt) + full = _row_count(client, _LARGE_QUERY, fmt) + assert streamed == full, f"{fmt}: bytes differ — streamed={len(streamed)} vs full={len(full)}" + + +def test_raw_stream_unsafe_format_json_yields_one_object(client): + """JSON includes per-run statistics, so check structural equality rather than bytes.""" + import json as _json + + streamed = _json.loads(_stream_full_bytes(client, _LARGE_QUERY, "JSON")) + full = _json.loads(_row_count(client, _LARGE_QUERY, "JSON")) + assert streamed["meta"] == full["meta"] + assert streamed["data"] == full["data"] + assert "statistics" in streamed and "statistics" in full + + +def 
test_arrow_stream_yields_all_record_batches(client): + """Regression: large Arrow stream must surface every RecordBatch, not just the first.""" + pa = pytest.importorskip("pyarrow") + stream = client.raw_stream(_LARGE_QUERY, fmt="ArrowStream") + try: + reader = pa.ipc.open_stream(stream) + batches = list(reader) + finally: + stream.close() + total_rows = sum(b.num_rows for b in batches) + assert total_rows == 200000, f"Lost rows in arrow stream: got {total_rows}" + + +def test_parquet_stream_is_single_file(client): + """Regression: Parquet output must be one valid file, not multiple concatenated.""" + pa = pytest.importorskip("pyarrow") + import pyarrow.parquet as pq + + stream = client.raw_stream(_LARGE_QUERY, fmt="Parquet") + try: + data = stream.read() + finally: + stream.close() + table = pq.read_table(pa.BufferReader(data)) + assert table.num_rows == 200000 + + +def test_jsoneachrow_stream_iterates_chunks(client): + """JSONEachRow stays on the streaming path (per-line format), verify chunked read.""" + stream = client.raw_stream(_LARGE_QUERY, fmt="JSONEachRow") + try: + first = stream.read(1024) + rest = stream.read() + finally: + stream.close() + # First chunk should start with valid JSON object + assert first.startswith(b'{"id":'), f"unexpected start: {first[:40]!r}" + # Total bytes equal the non-streaming version + assert len(first) + len(rest) == len(_row_count(client, _LARGE_QUERY, "JSONEachRow")) + + # ---- error mapping ---- @@ -233,6 +337,18 @@ def test_external_data_not_supported(client): client.query("SELECT * FROM x", external_data=ext) +def test_mid_stream_exception_surfaces_as_stream_failure(client): + """Mid-stream chdb errors must be raised as StreamFailureError to match HTTP semantics.""" + from clickhouse_connect.driver.exceptions import StreamFailureError + + query = "SELECT throwIf(number = 100) FROM numbers(1000) SETTINGS max_block_size = 10" + with pytest.raises(StreamFailureError) as ex_info: + with client.query_row_block_stream(query) as 
stream: + for _ in stream: + pass + assert "throwIf" in str(ex_info.value) or "Code: 395" in str(ex_info.value) + + # ---- HTTP-only kwargs accepted silently ---- @@ -706,10 +822,34 @@ def chunks(): assert r.result_rows == [(13, "user_1"), (79, "user_2")] -def test_raw_insert_compression_rejected(client): - client.command("CREATE TABLE raw_compress (id UInt32) ENGINE = Memory") +@pytest.mark.parametrize("compression", ["lz4", "zstd", "gzip"]) +def test_raw_insert_decompresses_pre_compressed_payload(client, compression): + """raw_insert with `compression=` accepts compressed bytes and decompresses client-side.""" + import gzip + + import lz4.frame + import zstandard + + csv = b"13,user_1\n79,user_2\n" + encoded = { + "lz4": lz4.frame.compress(csv), + "zstd": zstandard.ZstdCompressor().compress(csv), + "gzip": gzip.compress(csv), + }[compression] + client.command(f"CREATE TABLE raw_compress_{compression} (id UInt32, v String) ENGINE = Memory") + client.raw_insert( + f"raw_compress_{compression}", + insert_block=encoded, + fmt="CSV", + compression=compression, + ) + r = client.query(f"SELECT id, v FROM raw_compress_{compression} ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_unsupported_compression_raises(client): with pytest.raises(NotSupportedError): - client.raw_insert("raw_compress", insert_block=b"1\n", fmt="CSV", compression="lz4") + client.raw_insert("t", insert_block=b"1\n", fmt="CSV", compression="snappy") def test_raw_insert_missing_args(client): @@ -795,6 +935,21 @@ def test_query_with_parameters(client): assert r.result_rows == [(13, "user_1")] +def test_raw_query_with_embedded_binary_parameter(client): + """`$name$` placeholders inline raw bytes — chdb accepts bytes SQL, no decode.""" + binary_params = {"$xx$": b"col1,col2\n100,700"} + result = client.raw_query("SELECT col2, col1 FROM format(CSVWithNames, $xx$)", parameters=binary_params) + assert result == b"700\t100\n" + + +def 
test_raw_query_embedded_binary_with_non_utf8_bytes(client): + """Non-UTF-8 bytes (e.g. binary file content) embedded in SQL must round-trip.""" + payload = b"col1,col2\n100,\xff\x92" + result = client.raw_query("SELECT col2 FROM format(CSVWithNames, $xx$)", parameters={"$xx$": payload}) + # The non-UTF-8 byte sequence must come back intact in the output. + assert b"\xff" in result or b"\xc3\xbf" in result + + # ---- transport-only settings don't get persisted ----