Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/schematic/datastream/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@
MIN_RECONNECT_DELAY = 1.0 # seconds
MAX_RECONNECT_DELAY = 30.0 # seconds

# Headers attached to the WebSocket handshake so the backend can distinguish
# direct-SDK connections from the schematic-datastream-replicator and correlate
# either to a specific release. Mode is always "direct" here — replicator mode
# in this SDK doesn't open a WebSocket at all.
CLIENT_NAME = "schematic-python"
DATASTREAM_MODE_DIRECT = "direct"
UNKNOWN_VERSION = "unknown"


def _get_sdk_version() -> str:
"""Return the installed schematichq package version, or "unknown"."""
try:
from importlib import metadata

return metadata.version("schematichq")
except Exception:
return UNKNOWN_VERSION

MessageHandlerFunc = Callable[[DataStreamResp], Awaitable[None]]
ConnectionReadyHandlerFunc = Callable[[], Awaitable[None]]

Expand Down Expand Up @@ -135,7 +153,12 @@ def __init__(self, options: ClientOptions) -> None:
else:
self._url = options.url

self._headers: Dict[str, str] = {"X-Schematic-Api-Key": options.api_key}
self._headers: Dict[str, str] = {
"X-Schematic-Api-Key": options.api_key,
"X-Schematic-Datastream-Mode": DATASTREAM_MODE_DIRECT,
"X-Schematic-Client": CLIENT_NAME,
"X-Schematic-Client-Version": _get_sdk_version(),
}
self._logger = options.logger
self._message_handler = options.message_handler
self._connection_ready_handler = options.connection_ready_handler
Expand Down
61 changes: 61 additions & 0 deletions tests/datastream/test_websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from schematic.datastream.types import DataStreamBaseReq, DataStreamReq, DataStreamResp, EntityType
from schematic.datastream.websocket_client import (
_WS_HEADERS_KWARG,
ClientOptions,
DatastreamWSClient,
convert_api_url_to_websocket_url,
Expand Down Expand Up @@ -194,6 +195,66 @@ async def handler(msg): pass
assert client._url == "wss://datastream.schematichq.com/datastream"


# ---------------------------------------------------------------------------
# Handshake headers
# ---------------------------------------------------------------------------


def test_handshake_headers_include_identification() -> None:
"""The handshake carries mode/client/version headers so the backend can
distinguish direct-SDK connections from schematic-datastream-replicator
and correlate them to a release."""
async def handler(msg): pass

client = DatastreamWSClient(
ClientOptions(url="wss://example.com", api_key="my-key", message_handler=handler, logger=logger)
)

assert client._headers["X-Schematic-Api-Key"] == "my-key"
assert client._headers["X-Schematic-Datastream-Mode"] == "direct"
assert client._headers["X-Schematic-Client"] == "schematic-python"
assert "X-Schematic-Client-Version" in client._headers
assert client._headers["X-Schematic-Client-Version"] != ""


async def test_handshake_headers_passed_to_websockets_connect() -> None:
"""Confirms the identification headers actually reach websockets.connect."""
captured_kwargs: dict = {}
ws = MockWebSocket(block_on_empty=True)

@asynccontextmanager
async def capturing_connect(*args, **kwargs):
captured_kwargs.update(kwargs)
yield ws

connected = asyncio.Event()
client, ws, _ = make_client(ws=ws, on_connected=lambda: connected.set())

with patch("schematic.datastream.websocket_client.websockets.connect", capturing_connect):
async with run_client(client):
await asyncio.wait_for(connected.wait(), timeout=2.0)

headers = captured_kwargs.get(_WS_HEADERS_KWARG)
assert headers is not None
assert headers["X-Schematic-Api-Key"] == "test-key"
assert headers["X-Schematic-Datastream-Mode"] == "direct"
assert headers["X-Schematic-Client"] == "schematic-python"
assert headers.get("X-Schematic-Client-Version")


def test_get_sdk_version_falls_back_to_unknown_when_metadata_missing() -> None:
"""When importlib.metadata can't resolve the package (e.g. running from
an uninstalled checkout), the helper returns "unknown" rather than raising.
Matches schematic-go's behaviour for untagged builds."""
from importlib.metadata import PackageNotFoundError

from schematic.datastream import websocket_client

with patch.object(websocket_client, "_get_sdk_version", wraps=websocket_client._get_sdk_version):
with patch("importlib.metadata.version", side_effect=PackageNotFoundError("schematichq")):
assert websocket_client._get_sdk_version() == "unknown"


# ---------------------------------------------------------------------------
# send_message
# ---------------------------------------------------------------------------
Expand Down
Loading