diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f4ed829..dc787784 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Improvements - SQLAlchemy: opt-in server-side bind parameters via `create_engine(url, server_side_params=True)`. The dialect then emits ClickHouse native `{name:Type}` / `{name:Array(Type)}` placeholders instead of client-side string interpolation. Off by default. Closes [#735](https://github.com/ClickHouse/clickhouse-connect/issues/735). - Added a `token_provider` client option (sync and async). It accepts a callable returning an access token string; the callable is invoked once for the initial token and again to fetch a fresh token whenever the server rejects the current one (authentication failure), retrying the request once. Mutually exclusive with `access_token` and `username`/`password`. +- Added a `headers` option to `create_client`/`create_async_client` for attaching custom HTTP headers to every request, including the initialization queries sent during client creation. Useful for HTTP gateways that require auth headers such as Cloudflare Access service tokens. ### Bug Fixes - A `datetime` bound to a server-side `{name:DateTime64(...)}` placeholder now keeps its sub-second precision instead of being truncated to seconds. The declared parameter type drives this, so no `_64` name suffix or manual `DT64Param` wrapper is needed, and it applies through `Array` and `Tuple` hints. Plain `DateTime` binds are unchanged. Closes [#739](https://github.com/ClickHouse/clickhouse-connect/issues/739). diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index ca302a31..5e920f7b 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -89,6 +89,20 @@ def _validate_access_token(access_token: str | None, token_provider: Callable[[] raise ProgrammingError("Cannot use both access_token and token_provider") +def _pop_headers_arg(headers: Any | None, kwargs: dict[str, Any]) -> Any | None: + """Hoist headers parsed through generic kwargs while preserving explicit headers.""" + if "headers" in kwargs: + kwargs_headers = kwargs.pop("headers") + if headers is None: + headers = kwargs_headers + return headers + + +def _validate_headers(headers: Any | None) -> None: + if headers is not None and not isinstance(headers, dict): + raise ProgrammingError("headers must be a dictionary of HTTP header names and values") + + def create_client( *, host: str | None = None, @@ -102,6 +116,7 @@ def create_client( secure: bool | str = False, dsn: str | None = None, settings: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, generic_args: dict[str, Any] | None = None, **kwargs, ) -> Client: @@ -127,6 +142,9 @@ def create_client( :param dsn: A string in standard DSN (Data Source Name) format. Other connection values (such as host or user) will be extracted from this string if not set otherwise. :param settings: ClickHouse server settings to be used with the session/every request + :param headers: Additional HTTP headers to send with every request. This can be used for proxy or gateway + authentication, such as Cloudflare Access service token headers. These headers are applied after driver defaults, + so they can intentionally override headers such as Authorization or User-Agent. :param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings. It is not recommended to use this parameter externally. @@ -174,6 +192,7 @@ def create_client( host, username, password, port, database, interface = _parse_connection_params( host, username, password, port, database, interface, secure, dsn, kwargs ) + headers = _pop_headers_arg(headers, kwargs) _validate_access_token(access_token, token_provider, username, password) settings = settings or {} @@ -181,7 +200,10 @@ def create_client( if generic_args: client_params = signature(HttpClient).parameters for name, value in generic_args.items(): - if name in client_params: + if name == "headers": + if headers is None: + headers = value + elif name in client_params: kwargs[name] = value elif name == "compression": if "compress" not in kwargs: @@ -196,6 +218,7 @@ def create_client( access_token = access_token or generic_access token_provider = token_provider or generic_token _validate_access_token(access_token, token_provider, username, password) + _validate_headers(headers) return HttpClient( interface, host, @@ -206,6 +229,7 @@ def create_client( access_token, token_provider=token_provider, settings=settings, + headers=headers, **kwargs, ) raise ProgrammingError(f"Unrecognized client type {interface}") @@ -224,6 +248,7 @@ async def create_async_client( secure: bool | str = False, dsn: str | None = None, settings: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, generic_args: dict[str, Any] | None = None, connector_limit: int = 100, connector_limit_per_host: int = 20, @@ -254,6 +279,9 @@ async def create_async_client( :param dsn: A string in standard DSN (Data Source Name) format. Other connection values (such as host or user) will be extracted from this string if not set otherwise. :param settings: ClickHouse server settings to be used with the session/every request + :param headers: Additional HTTP headers to send with every request. This can be used for proxy or gateway + authentication, such as Cloudflare Access service token headers. These headers are applied after driver defaults, + so they can intentionally override headers such as Authorization or User-Agent. :param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings. It is not recommended to use this parameter externally :param connector_limit: Maximum number of allowable connections to the server @@ -313,13 +341,17 @@ async def create_async_client( host, username, password, port, database, interface = _parse_connection_params( host, username, password, port, database, interface, secure, dsn, kwargs ) + headers = _pop_headers_arg(headers, kwargs) _validate_access_token(access_token, token_provider, username, password) settings = settings or {} if generic_args: client_params = signature(_AsyncClient).parameters for name, value in generic_args.items(): - if name in client_params: + if name == "headers": + if headers is None: + headers = value + elif name in client_params: kwargs[name] = value elif name == "compression": if "compress" not in kwargs: @@ -338,7 +370,7 @@ async def create_async_client( access_token = access_token or generic_access token_provider = token_provider or generic_token _validate_access_token(access_token, token_provider, username, password) - + _validate_headers(headers) client = _AsyncClient( interface=interface, host=host, @@ -349,6 +381,7 @@ async def create_async_client( access_token=access_token, token_provider=token_provider, settings=settings, + headers=headers, connector_limit=connector_limit, connector_limit_per_host=connector_limit_per_host, keepalive_timeout=keepalive_timeout, diff --git a/clickhouse_connect/driver/asyncclient.py b/clickhouse_connect/driver/asyncclient.py index 485bec48..eef16ed6 100644 --- a/clickhouse_connect/driver/asyncclient.py +++ b/clickhouse_connect/driver/asyncclient.py @@ -227,6 +227,7 @@ def __init__( autogenerate_query_id: bool | None = None, form_encode_query_params: bool = False, rename_response_column: str | None = None, + headers: dict[str, str] | None = None, ): """ Async HTTP Client using aiohttp. Initialization is handled via _initialize(). @@ -334,6 +335,8 @@ def __init__( self._reported_libs = set() self._last_pool_reset = None self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;") + if headers: + self.headers.update(headers) # Store aiohttp-specific params for deferred initialization self._compress_param = compress diff --git a/clickhouse_connect/driver/httpclient.py b/clickhouse_connect/driver/httpclient.py index 6daa3b06..e3737a84 100644 --- a/clickhouse_connect/driver/httpclient.py +++ b/clickhouse_connect/driver/httpclient.py @@ -106,6 +106,7 @@ def __init__( proxy_path: str = "", form_encode_query_params: bool = False, rename_response_column: str | None = None, + headers: dict[str, str] | None = None, ): """ Create an HTTP ClickHouse Connect client @@ -162,6 +163,8 @@ def __init__( self._reported_libs = set() self.headers["User-Agent"] = common.build_client_name(client_name) + if headers: + self.headers.update(headers) self._read_format = self._write_format = "Native" self._transform = NativeTransform() @@ -756,7 +759,12 @@ def ping(self) -> bool: See BaseClient doc_string for this method """ try: - response = self.http.request("GET", f"{self.url}/ping", timeout=3, preload_content=True) + headers = dict_copy(self.headers) + kwargs = {"headers": headers, "timeout": 3, "preload_content": True} + if self.server_host_name: + kwargs["assert_same_host"] = False + headers["Host"] = self.server_host_name + response = self.http.request("GET", f"{self.url}/ping", **kwargs) return 200 <= response.status < 300 except HTTPError: logger.debug("ping failed", exc_info=True) diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index de22f95a..4999bce4 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -58,6 +58,19 @@ def test_transport_settings(param_client, call): assert len(result.result_set) > 0 +def test_client_headers(client_factory, call): + client = client_factory( + headers={ + "CF-Access-Client-Id": "test_client_id", + "CF-Access-Client-Secret": "test_client_secret", + } + ) + + assert client.headers["CF-Access-Client-Id"] == "test_client_id" + assert client.headers["CF-Access-Client-Secret"] == "test_client_secret" + assert call(client.command, "SELECT 79") == 79 + + def test_none_database(param_client, call): old_db = param_client.database test_db = call(param_client.command, "select currentDatabase()") diff --git a/tests/unit_tests/test_driver/test_httpclient.py b/tests/unit_tests/test_driver/test_httpclient.py index 90e1c8c9..b2c50dad 100644 --- a/tests/unit_tests/test_driver/test_httpclient.py +++ b/tests/unit_tests/test_driver/test_httpclient.py @@ -1,10 +1,13 @@ import logging from typing import Any -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError +from clickhouse_connect.driver import create_async_client, create_client +from clickhouse_connect.driver.asyncclient import AsyncClient +from clickhouse_connect.driver.client import Client +from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError from clickhouse_connect.driver.external import ExternalData from clickhouse_connect.driver.httpclient import HttpClient, ex_header from clickhouse_connect.driver.query import QueryContext @@ -21,6 +24,152 @@ def create_mock_response(status=500, headers=None, data=None): return response +class TestHttpClientHeaders: + """Test client-level HTTP header configuration.""" + + def test_headers_are_available_during_initialization(self): + init_headers = {} + + def capture_headers(client, _tz_source): + init_headers.update(client.headers) + + with patch.object(Client, "_init_common_settings", autospec=True, side_effect=capture_headers): + HttpClient( + interface="http", + host="localhost", + port=8123, + username="default", + password="", + database="default", + headers={ + "CF-Access-Client-Id": "test_client_id", + "CF-Access-Client-Secret": "test_client_secret", + }, + ) + + assert init_headers["CF-Access-Client-Id"] == "test_client_id" + assert init_headers["CF-Access-Client-Secret"] == "test_client_secret" + assert "Authorization" in init_headers + assert "User-Agent" in init_headers + + def test_request_headers_override_client_headers(self): + response = create_mock_response(status=200) + pool_mgr = Mock() + pool_mgr.request.return_value = response + + with patch.object(Client, "_init_common_settings", autospec=True): + client = HttpClient( + interface="http", + host="localhost", + port=8123, + username="default", + password="", + database="default", + pool_mgr=pool_mgr, + headers={"X-Trace": "client", "X-Gateway": "cloudflare"}, + ) + + client._raw_request(b"", {}, headers={"X-Trace": "request"}) + + request_headers = pool_mgr.request.call_args.kwargs["headers"] + assert request_headers["X-Trace"] == "request" + assert request_headers["X-Gateway"] == "cloudflare" + assert request_headers["Authorization"] == client.headers["Authorization"] + assert request_headers["User-Agent"] == client.headers["User-Agent"] + + def test_ping_uses_client_headers(self): + response = create_mock_response(status=200) + pool_mgr = Mock() + pool_mgr.request.return_value = response + + with patch.object(Client, "_init_common_settings", autospec=True): + client = HttpClient( + interface="http", + host="localhost", + port=8123, + username="default", + password="", + database="default", + pool_mgr=pool_mgr, + server_host_name="clickhouse.example.com", + headers={"X-Gateway": "cloudflare"}, + ) + + assert client.ping() is True + + request_headers = pool_mgr.request.call_args.kwargs["headers"] + assert request_headers["X-Gateway"] == "cloudflare" + assert request_headers["Authorization"] == client.headers["Authorization"] + assert request_headers["User-Agent"] == client.headers["User-Agent"] + assert request_headers["Host"] == "clickhouse.example.com" + assert pool_mgr.request.call_args.kwargs["assert_same_host"] is False + + def test_dsn_headers_query_param_must_be_dict(self): + with pytest.raises(ProgrammingError, match="headers must be a dictionary"): + create_client(dsn="http://localhost:8123/default?headers=not_a_dict") + + def test_explicit_headers_override_dsn_headers_query_param(self): + init_headers = {} + + def capture_headers(client, _tz_source): + init_headers.update(client.headers) + + with patch.object(Client, "_init_common_settings", autospec=True, side_effect=capture_headers): + create_client( + dsn="http://localhost:8123/default?headers=not_a_dict", + headers={"X-Gateway": "cloudflare"}, + ) + + assert init_headers["X-Gateway"] == "cloudflare" + + +class TestAsyncClientHeaders: + """Test async client-level HTTP header configuration.""" + + @pytest.mark.asyncio + async def test_request_headers_override_client_headers(self): + client = AsyncClient( + interface="http", + host="localhost", + port=8123, + username="default", + password="", + database="default", + headers={"X-Trace": "client", "X-Gateway": "cloudflare"}, + ) + session = Mock() + session.closed = False + response = Mock() + response.status = 200 + response.headers = {} + session.request = AsyncMock(return_value=response) + client._session = session + + await client._raw_request(None, {}, headers={"X-Trace": "request"}) + + request_headers = session.request.call_args.kwargs["headers"] + assert request_headers["X-Trace"] == "request" + assert request_headers["X-Gateway"] == "cloudflare" + assert request_headers["Authorization"] == client.headers["Authorization"] + assert request_headers["User-Agent"] == client.headers["User-Agent"] + assert request_headers["Accept-Encoding"] == client.headers["Accept-Encoding"] + + @pytest.mark.asyncio + async def test_dsn_headers_query_param_must_be_dict(self): + with pytest.raises(ProgrammingError, match="headers must be a dictionary"): + await create_async_client(dsn="http://localhost:8123/default?headers=not_a_dict") + + @pytest.mark.asyncio + async def test_explicit_headers_override_dsn_headers_query_param(self): + with patch.object(AsyncClient, "_initialize", new=AsyncMock()): + client = await create_async_client( + dsn="http://localhost:8123/default?headers=not_a_dict", + headers={"X-Gateway": "cloudflare"}, + ) + + assert client.headers["X-Gateway"] == "cloudflare" + + class TestHttpClientErrorHandler: """Test the error handling functionality of HttpClient"""