Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Improvements
- SQLAlchemy: opt-in server-side bind parameters via `create_engine(url, server_side_params=True)`. The dialect then emits ClickHouse native `{name:Type}` / `{name:Array(Type)}` placeholders instead of client-side string interpolation. Off by default. Closes [#735](https://github.com/ClickHouse/clickhouse-connect/issues/735).
- Added a `token_provider` client option (sync and async). It accepts a callable returning an access token string; the callable is invoked once for the initial token and again to fetch a fresh token whenever the server rejects the current one (authentication failure), retrying the request once. Mutually exclusive with `access_token` and `username`/`password`.
- Added a `headers` option to `create_client`/`create_async_client` for attaching custom HTTP headers to every request, including the initialization queries sent during client creation. Useful for HTTP gateways that require auth headers such as Cloudflare Access service tokens.

### Bug Fixes
- A `datetime` bound to a server-side `{name:DateTime64(...)}` placeholder now keeps its sub-second precision instead of being truncated to seconds. The declared parameter type drives this, so no `_64` name suffix or manual `DT64Param` wrapper is needed, and it applies through `Array` and `Tuple` hints. Plain `DateTime` binds are unchanged. Closes [#739](https://github.com/ClickHouse/clickhouse-connect/issues/739).
Expand Down
39 changes: 36 additions & 3 deletions clickhouse_connect/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ def _validate_access_token(access_token: str | None, token_provider: Callable[[]
raise ProgrammingError("Cannot use both access_token and token_provider")


def _pop_headers_arg(headers: Any | None, kwargs: dict[str, Any]) -> Any | None:
"""Hoist headers parsed through generic kwargs while preserving explicit headers."""
if "headers" in kwargs:
kwargs_headers = kwargs.pop("headers")
if headers is None:
headers = kwargs_headers
return headers


def _validate_headers(headers: Any | None) -> None:
    """Check that ``headers`` is either None or a dict.

    :param headers: Candidate headers value to validate
    :raises ProgrammingError: if ``headers`` is neither None nor a ``dict`` instance
    """
    if headers is None or isinstance(headers, dict):
        return
    raise ProgrammingError("headers must be a dictionary of HTTP header names and values")


def create_client(
*,
host: str | None = None,
Expand All @@ -102,6 +116,7 @@ def create_client(
secure: bool | str = False,
dsn: str | None = None,
settings: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
generic_args: dict[str, Any] | None = None,
**kwargs,
) -> Client:
Expand All @@ -127,6 +142,9 @@ def create_client(
:param dsn: A string in standard DSN (Data Source Name) format. Other connection values (such as host or user)
will be extracted from this string if not set otherwise.
:param settings: ClickHouse server settings to be used with the session/every request
:param headers: Additional HTTP headers to send with every request. This can be used for proxy or gateway
authentication, such as Cloudflare Access service token headers. These headers are applied after driver defaults,
so they can intentionally override headers such as Authorization or User-Agent.
:param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings.
It is not recommended to use this parameter externally.

Expand Down Expand Up @@ -174,14 +192,18 @@ def create_client(
host, username, password, port, database, interface = _parse_connection_params(
host, username, password, port, database, interface, secure, dsn, kwargs
)
headers = _pop_headers_arg(headers, kwargs)
_validate_access_token(access_token, token_provider, username, password)

settings = settings or {}
if interface.startswith("http"):
if generic_args:
client_params = signature(HttpClient).parameters
for name, value in generic_args.items():
if name in client_params:
if name == "headers":
if headers is None:
headers = value
elif name in client_params:
kwargs[name] = value
elif name == "compression":
if "compress" not in kwargs:
Expand All @@ -196,6 +218,7 @@ def create_client(
access_token = access_token or generic_access
token_provider = token_provider or generic_token
_validate_access_token(access_token, token_provider, username, password)
_validate_headers(headers)
return HttpClient(
interface,
host,
Expand All @@ -206,6 +229,7 @@ def create_client(
access_token,
token_provider=token_provider,
settings=settings,
headers=headers,
**kwargs,
)
Comment thread
joe-clickhouse marked this conversation as resolved.
raise ProgrammingError(f"Unrecognized client type {interface}")
Expand All @@ -224,6 +248,7 @@ async def create_async_client(
secure: bool | str = False,
dsn: str | None = None,
settings: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
generic_args: dict[str, Any] | None = None,
connector_limit: int = 100,
connector_limit_per_host: int = 20,
Expand Down Expand Up @@ -254,6 +279,9 @@ async def create_async_client(
:param dsn: A string in standard DSN (Data Source Name) format. Other connection values (such as host or user)
will be extracted from this string if not set otherwise.
:param settings: ClickHouse server settings to be used with the session/every request
:param headers: Additional HTTP headers to send with every request. This can be used for proxy or gateway
authentication, such as Cloudflare Access service token headers. These headers are applied after driver defaults,
so they can intentionally override headers such as Authorization or User-Agent.
:param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings.
It is not recommended to use this parameter externally
:param connector_limit: Maximum number of allowable connections to the server
Expand Down Expand Up @@ -313,13 +341,17 @@ async def create_async_client(
host, username, password, port, database, interface = _parse_connection_params(
host, username, password, port, database, interface, secure, dsn, kwargs
)
headers = _pop_headers_arg(headers, kwargs)
_validate_access_token(access_token, token_provider, username, password)

settings = settings or {}
if generic_args:
client_params = signature(_AsyncClient).parameters
for name, value in generic_args.items():
if name in client_params:
if name == "headers":
if headers is None:
headers = value
elif name in client_params:
kwargs[name] = value
elif name == "compression":
if "compress" not in kwargs:
Expand All @@ -338,7 +370,7 @@ async def create_async_client(
access_token = access_token or generic_access
token_provider = token_provider or generic_token
_validate_access_token(access_token, token_provider, username, password)

_validate_headers(headers)
client = _AsyncClient(
interface=interface,
host=host,
Expand All @@ -349,6 +381,7 @@ async def create_async_client(
access_token=access_token,
token_provider=token_provider,
settings=settings,
headers=headers,
connector_limit=connector_limit,
connector_limit_per_host=connector_limit_per_host,
keepalive_timeout=keepalive_timeout,
Expand Down
3 changes: 3 additions & 0 deletions clickhouse_connect/driver/asyncclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(
autogenerate_query_id: bool | None = None,
form_encode_query_params: bool = False,
rename_response_column: str | None = None,
headers: dict[str, str] | None = None,
):
"""
Async HTTP Client using aiohttp. Initialization is handled via _initialize().
Expand Down Expand Up @@ -334,6 +335,8 @@ def __init__(
self._reported_libs = set()
self._last_pool_reset = None
self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;")
if headers:
self.headers.update(headers)

# Store aiohttp-specific params for deferred initialization
self._compress_param = compress
Expand Down
10 changes: 9 additions & 1 deletion clickhouse_connect/driver/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
proxy_path: str = "",
form_encode_query_params: bool = False,
rename_response_column: str | None = None,
headers: dict[str, str] | None = None,
):
"""
Create an HTTP ClickHouse Connect client
Expand Down Expand Up @@ -162,6 +163,8 @@ def __init__(

self._reported_libs = set()
self.headers["User-Agent"] = common.build_client_name(client_name)
if headers:
self.headers.update(headers)
self._read_format = self._write_format = "Native"
self._transform = NativeTransform()

Expand Down Expand Up @@ -756,7 +759,12 @@ def ping(self) -> bool:
See BaseClient doc_string for this method
"""
try:
response = self.http.request("GET", f"{self.url}/ping", timeout=3, preload_content=True)
headers = dict_copy(self.headers)
kwargs = {"headers": headers, "timeout": 3, "preload_content": True}
if self.server_host_name:
kwargs["assert_same_host"] = False
headers["Host"] = self.server_host_name
response = self.http.request("GET", f"{self.url}/ping", **kwargs)
return 200 <= response.status < 300
except HTTPError:
logger.debug("ping failed", exc_info=True)
Expand Down
13 changes: 13 additions & 0 deletions tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ def test_transport_settings(param_client, call):
assert len(result.result_set) > 0


def test_client_headers(client_factory, call):
    """Headers given to the client factory are exposed on the client and a simple command still works."""
    expected = {
        "CF-Access-Client-Id": "test_client_id",
        "CF-Access-Client-Secret": "test_client_secret",
    }
    # Hand the factory its own copy so the expectations below cannot be altered by the client.
    client = client_factory(headers=dict(expected))

    for name, value in expected.items():
        assert client.headers[name] == value
    assert call(client.command, "SELECT 79") == 79


def test_none_database(param_client, call):
old_db = param_client.database
test_db = call(param_client.command, "select currentDatabase()")
Expand Down
153 changes: 151 additions & 2 deletions tests/unit_tests/test_driver/test_httpclient.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
from typing import Any
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError
from clickhouse_connect.driver import create_async_client, create_client
from clickhouse_connect.driver.asyncclient import AsyncClient
from clickhouse_connect.driver.client import Client
from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError
from clickhouse_connect.driver.external import ExternalData
from clickhouse_connect.driver.httpclient import HttpClient, ex_header
from clickhouse_connect.driver.query import QueryContext
Expand All @@ -21,6 +24,152 @@ def create_mock_response(status=500, headers=None, data=None):
return response


class TestHttpClientHeaders:
    """Test client-level HTTP header configuration."""

    # Connection arguments shared by every test that constructs an HttpClient directly.
    _CONNECT_ARGS = {
        "interface": "http",
        "host": "localhost",
        "port": 8123,
        "username": "default",
        "password": "",
        "database": "default",
    }

    def test_headers_are_available_during_initialization(self):
        """Constructor headers are already on the client when _init_common_settings is invoked."""
        seen = {}

        def record_headers(client, _tz_source):
            seen.update(client.headers)

        gateway_headers = {
            "CF-Access-Client-Id": "test_client_id",
            "CF-Access-Client-Secret": "test_client_secret",
        }
        with patch.object(Client, "_init_common_settings", autospec=True, side_effect=record_headers):
            HttpClient(**self._CONNECT_ARGS, headers=dict(gateway_headers))

        for name, value in gateway_headers.items():
            assert seen[name] == value
        for name in ("Authorization", "User-Agent"):
            assert name in seen

    def test_request_headers_override_client_headers(self):
        """Headers passed to _raw_request win over client headers; the rest are still sent."""
        pool_mgr = Mock()
        pool_mgr.request.return_value = create_mock_response(status=200)

        with patch.object(Client, "_init_common_settings", autospec=True):
            client = HttpClient(
                **self._CONNECT_ARGS,
                pool_mgr=pool_mgr,
                headers={"X-Trace": "client", "X-Gateway": "cloudflare"},
            )

        client._raw_request(b"", {}, headers={"X-Trace": "request"})

        sent = pool_mgr.request.call_args.kwargs["headers"]
        assert sent["X-Trace"] == "request"
        assert sent["X-Gateway"] == "cloudflare"
        for name in ("Authorization", "User-Agent"):
            assert sent[name] == client.headers[name]

    def test_ping_uses_client_headers(self):
        """ping() sends the client headers plus a Host header when server_host_name is set."""
        pool_mgr = Mock()
        pool_mgr.request.return_value = create_mock_response(status=200)

        with patch.object(Client, "_init_common_settings", autospec=True):
            client = HttpClient(
                **self._CONNECT_ARGS,
                pool_mgr=pool_mgr,
                server_host_name="clickhouse.example.com",
                headers={"X-Gateway": "cloudflare"},
            )

        assert client.ping() is True

        request_kwargs = pool_mgr.request.call_args.kwargs
        sent = request_kwargs["headers"]
        assert sent["X-Gateway"] == "cloudflare"
        for name in ("Authorization", "User-Agent"):
            assert sent[name] == client.headers[name]
        assert sent["Host"] == "clickhouse.example.com"
        assert request_kwargs["assert_same_host"] is False

    def test_dsn_headers_query_param_must_be_dict(self):
        """A non-dict ``headers`` value arriving via the DSN query string is rejected."""
        bad_dsn = "http://localhost:8123/default?headers=not_a_dict"
        with pytest.raises(ProgrammingError, match="headers must be a dictionary"):
            create_client(dsn=bad_dsn)

    def test_explicit_headers_override_dsn_headers_query_param(self):
        """An explicit ``headers`` argument is used even when the DSN also carries a ``headers`` value."""
        seen = {}

        def record_headers(client, _tz_source):
            seen.update(client.headers)

        with patch.object(Client, "_init_common_settings", autospec=True, side_effect=record_headers):
            create_client(
                dsn="http://localhost:8123/default?headers=not_a_dict",
                headers={"X-Gateway": "cloudflare"},
            )

        assert seen["X-Gateway"] == "cloudflare"


class TestAsyncClientHeaders:
    """Test async client-level HTTP header configuration."""

    @pytest.mark.asyncio
    async def test_request_headers_override_client_headers(self):
        """Headers passed to _raw_request win over client headers; the rest are still sent."""
        client = AsyncClient(
            interface="http",
            host="localhost",
            port=8123,
            username="default",
            password="",
            database="default",
            headers={"X-Trace": "client", "X-Gateway": "cloudflare"},
        )
        # Replace the session with a stand-in whose request() resolves to a bare 200 response.
        ok_response = Mock(status=200, headers={})
        fake_session = Mock(closed=False)
        fake_session.request = AsyncMock(return_value=ok_response)
        client._session = fake_session

        await client._raw_request(None, {}, headers={"X-Trace": "request"})

        sent = fake_session.request.call_args.kwargs["headers"]
        assert sent["X-Trace"] == "request"
        assert sent["X-Gateway"] == "cloudflare"
        for name in ("Authorization", "User-Agent", "Accept-Encoding"):
            assert sent[name] == client.headers[name]

    @pytest.mark.asyncio
    async def test_dsn_headers_query_param_must_be_dict(self):
        """A non-dict ``headers`` value arriving via the DSN query string is rejected."""
        bad_dsn = "http://localhost:8123/default?headers=not_a_dict"
        with pytest.raises(ProgrammingError, match="headers must be a dictionary"):
            await create_async_client(dsn=bad_dsn)

    @pytest.mark.asyncio
    async def test_explicit_headers_override_dsn_headers_query_param(self):
        """An explicit ``headers`` argument is used even when the DSN also carries a ``headers`` value."""
        # _initialize is stubbed out so creating the client does not run its real initialization.
        with patch.object(AsyncClient, "_initialize", new=AsyncMock()):
            client = await create_async_client(
                dsn="http://localhost:8123/default?headers=not_a_dict",
                headers={"X-Gateway": "cloudflare"},
            )

        assert client.headers["X-Gateway"] == "cloudflare"


class TestHttpClientErrorHandler:
"""Test the error handling functionality of HttpClient"""

Expand Down