diff --git a/src/schematic/client.py b/src/schematic/client.py index 673d248..7795ec0 100644 --- a/src/schematic/client.py +++ b/src/schematic/client.py @@ -1,4 +1,5 @@ import atexit +import datetime as dt import logging from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union @@ -38,6 +39,66 @@ class CheckFlagOptions: default_value: Optional[Union[bool, Callable[[], bool]]] = None timeout: Optional[float] = None + # Client-supplied dedupe key for the resulting flag_check event. Only + # applied when the SDK evaluates the flag locally via DataStream and + # fires its own flag_check event (the REST API path sets its own). + # Duplicate events with the same key are dropped server-side for 24h. + idempotency_key: Optional[str] = None + + +@dataclass +class TrackOptions: + """Optional metadata for a track event. + + Fields map directly to the corresponding ``CreateEventRequestBody`` + properties. Omit any field you don't need; the SDK only sends fields + that are explicitly set. + """ + + # Client-supplied dedupe key. Duplicate events with the same key + # (scoped to the environment) are dropped server-side for 24 hours. + idempotency_key: Optional[str] = None + # Timestamp the event was sent. Required when trusted_client_clock=True. + sent_at: Optional[dt.datetime] = None + # When True, use sent_at as the effective event timestamp instead of + # server receipt time. Requires a secret API key and sent_at. + trusted_client_clock: Optional[bool] = None + # Import historical data without affecting billing. Requires a secret + # API key and trusted_client_clock. + backfill: Optional[bool] = None + + +@dataclass +class IdentifyOptions: + """Optional metadata for an identify event. + + Fields map directly to the corresponding ``CreateEventRequestBody`` + properties. Omit any field you don't need; the SDK only sends fields + that are explicitly set. + """ + + # Client-supplied dedupe key. Duplicate events with the same key + # (scoped to the environment) are dropped server-side for 24 hours. + idempotency_key: Optional[str] = None + + +def _event_options_to_kwargs( + options: Optional[Union[TrackOptions, IdentifyOptions, CheckFlagOptions]], +) -> Dict[str, Any]: + """Flatten an options dataclass into kwargs for CreateEventRequestBody. + + Only fields that were explicitly set on the dataclass are returned, so + unset fields don't override CreateEventRequestBody's own defaults and + don't appear on the wire as explicit nulls. + """ + if options is None: + return {} + kwargs: Dict[str, Any] = {} + for field in ("idempotency_key", "sent_at", "trusted_client_clock", "backfill"): + value = getattr(options, field, None) + if value is not None: + kwargs[field] = value + return kwargs @dataclass @@ -274,6 +335,7 @@ def identify( company: Optional[EventBodyIdentifyCompany] = None, name: Optional[str] = None, traits: Optional[Dict[str, Any]] = None, + options: Optional[IdentifyOptions] = None, ) -> None: self._enqueue_event( "identify", @@ -283,6 +345,7 @@ def identify( name=name, traits=traits, ), + options=options, ) def track( @@ -292,6 +355,7 @@ def track( user: Optional[Dict[str, str]] = None, traits: Optional[Dict[str, Any]] = None, quantity: Optional[int] = None, + options: Optional[TrackOptions] = None, ) -> None: self._enqueue_event( "track", @@ -302,13 +366,23 @@ def track( traits=traits, user=user, ), + options=options, ) - def _enqueue_event(self, event_type: str, body: EventBody) -> None: + def _enqueue_event( + self, + event_type: str, + body: EventBody, + options: Optional[Union[TrackOptions, IdentifyOptions, CheckFlagOptions]] = None, + ) -> None: if self.offline: return try: - event_body = CreateEventRequestBody(event_type=event_type, body=body) + event_body = CreateEventRequestBody( + event_type=event_type, + body=body, + **_event_options_to_kwargs(options), + ) self.event_buffer.push(event_body) except Exception as e: self.logger.error(e) @@ -492,7 +566,7 @@ async def check_flag_with_entitlement( CheckFlagRequestBody(company=company, user=user), flag_key, ) - await self._enqueue_flag_check_event(flag_key, resp, company, user) + await self._enqueue_flag_check_event(flag_key, resp, company, user, options) return self._ds_result_to_response(flag_key, resp, options) except Exception as e: self.logger.debug(f"Datastream flag check failed ({e}), falling back to API") @@ -627,6 +701,7 @@ async def _enqueue_flag_check_event( resp: RulesengineCheckFlagResult, company: Optional[Dict[str, str]], user: Optional[Dict[str, str]], + options: Optional[CheckFlagOptions] = None, ) -> None: """Enqueue a flag_check event for a DataStream-evaluated flag.""" await self._enqueue_event( @@ -642,6 +717,7 @@ async def _enqueue_flag_check_event( req_company=company, req_user=user, ), + options=options, ) def _ds_result_to_response( @@ -700,6 +776,7 @@ async def identify( company: Optional[EventBodyIdentifyCompany] = None, name: Optional[str] = None, traits: Optional[Dict[str, Any]] = None, + options: Optional[IdentifyOptions] = None, ) -> None: await self._enqueue_event( "identify", @@ -709,6 +786,7 @@ async def identify( name=name, traits=traits, ), + options=options, ) async def track( @@ -718,6 +796,7 @@ async def track( user: Optional[Dict[str, str]] = None, traits: Optional[Dict[str, Any]] = None, quantity: Optional[int] = None, + options: Optional[TrackOptions] = None, ) -> None: await self._enqueue_event( "track", @@ -728,6 +807,7 @@ async def track( traits=traits, user=user, ), + options=options, ) # Update company metrics in DataStream if available and connected @@ -742,11 +822,20 @@ async def track( except Exception as e: self.logger.error(f"Failed to update company metrics: {e}") - async def _enqueue_event(self, event_type: str, body: EventBody) -> None: + async def _enqueue_event( + self, + event_type: str, + body: EventBody, + options: Optional[Union[TrackOptions, IdentifyOptions, CheckFlagOptions]] = None, + ) -> None: if self.offline: return try: - event_body = CreateEventRequestBody(event_type=event_type, body=body) + event_body = CreateEventRequestBody( + event_type=event_type, + body=body, + **_event_options_to_kwargs(options), + ) await self.event_buffer.push(event_body) except Exception as e: self.logger.error(e) diff --git a/src/schematic/event_capture.py b/src/schematic/event_capture.py index 8f7196f..bc9385b 100644 --- a/src/schematic/event_capture.py +++ b/src/schematic/event_capture.py @@ -14,13 +14,18 @@ class _CaptureEventPayload(UniversalBaseModel): """Wire format for a single event sent to the capture service. Mirrors the shape used by the Go/Ruby/C# SDKs: `type` (not `event_type`) - and an `api_key` field embedded on each event. + and an `api_key` field embedded on each event. The optional metadata + fields (idempotency_key, sent_at, trusted_client_clock, backfill) map + directly to the equivalent fields on ``CreateEventRequestBody``. """ api_key: str = pydantic.Field() body: typing.Optional[EventBody] = None type: EventType = pydantic.Field() + idempotency_key: typing.Optional[str] = None sent_at: typing.Optional[dt.datetime] = None + trusted_client_clock: typing.Optional[bool] = None + backfill: typing.Optional[bool] = None class _CaptureBatchPayload(UniversalBaseModel): @@ -28,12 +33,25 @@ class _CaptureBatchPayload(UniversalBaseModel): def _to_payload(event: CreateEventRequestBody, api_key: str) -> _CaptureEventPayload: - return _CaptureEventPayload( - api_key=api_key, - body=event.body, - type=event.event_type, - sent_at=event.sent_at, - ) + # Build kwargs conditionally so unset optional fields stay unset on the + # model. The capture wire format uses `exclude_unset`-style semantics — + # we don't want to send `"idempotency_key": null` for events that didn't + # set one. + kwargs: typing.Dict[str, typing.Any] = { + "api_key": api_key, + "type": event.event_type, + } + if event.body is not None: + kwargs["body"] = event.body + if event.idempotency_key is not None: + kwargs["idempotency_key"] = event.idempotency_key + if event.sent_at is not None: + kwargs["sent_at"] = event.sent_at + if event.trusted_client_clock is not None: + kwargs["trusted_client_clock"] = event.trusted_client_clock + if event.backfill is not None: + kwargs["backfill"] = event.backfill + return _CaptureEventPayload(**kwargs) def _build_endpoint(base_url: str) -> str: diff --git a/tests/custom/test_client.py b/tests/custom/test_client.py index 592a944..e5c899d 100644 --- a/tests/custom/test_client.py +++ b/tests/custom/test_client.py @@ -12,8 +12,10 @@ AsyncSchematic, AsyncSchematicConfig, CheckFlagOptions, + IdentifyOptions, Schematic, SchematicConfig, + TrackOptions, ) from schematic.types import CheckFlagResponseData, FeatureEntitlement @@ -159,6 +161,87 @@ def test_track_with_quantity(self): ) mock_push.assert_called_once() + def test_track_with_idempotency_key(self): + """idempotency_key set via TrackOptions must land on the + CreateEventRequestBody pushed to the event buffer so the server can + dedupe on it.""" + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.track( + event="credit-consumed", + company={"id": "company_id"}, + options=TrackOptions(idempotency_key="evt_abc123"), + ) + mock_push.assert_called_once() + pushed = mock_push.call_args.args[0] + self.assertEqual(pushed.idempotency_key, "evt_abc123") + + def test_track_without_options_leaves_optional_fields_none(self): + """Options are opt-in — omitting `options` must leave every optional + metadata field at None on the wire.""" + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.track( + event="some-event", + company={"id": "company_id"}, + ) + pushed = mock_push.call_args.args[0] + self.assertIsNone(pushed.idempotency_key) + self.assertIsNone(pushed.sent_at) + self.assertIsNone(pushed.trusted_client_clock) + self.assertIsNone(pushed.backfill) + + def test_track_with_full_options(self): + """Every TrackOptions field should land on the CreateEventRequestBody.""" + import datetime as dt + sent_at = dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt.timezone.utc) + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.track( + event="historical-import", + company={"id": "company_id"}, + options=TrackOptions( + idempotency_key="evt_xyz", + sent_at=sent_at, + trusted_client_clock=True, + backfill=True, + ), + ) + pushed = mock_push.call_args.args[0] + self.assertEqual(pushed.idempotency_key, "evt_xyz") + self.assertEqual(pushed.sent_at, sent_at) + self.assertTrue(pushed.trusted_client_clock) + self.assertTrue(pushed.backfill) + + def test_track_partial_options(self): + """Unset TrackOptions fields stay None on the CreateEventRequestBody — + we don't accidentally send explicit nulls for things the caller didn't ask for.""" + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.track( + event="some-event", + company={"id": "company_id"}, + options=TrackOptions(idempotency_key="just-the-key"), + ) + pushed = mock_push.call_args.args[0] + self.assertEqual(pushed.idempotency_key, "just-the-key") + self.assertIsNone(pushed.sent_at) + self.assertIsNone(pushed.trusted_client_clock) + self.assertIsNone(pushed.backfill) + + def test_identify_with_options(self): + """IdentifyOptions must plumb through to the CreateEventRequestBody.""" + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.identify( + keys={"id": "user_id"}, + options=IdentifyOptions(idempotency_key="ident_123"), + ) + pushed = mock_push.call_args.args[0] + self.assertEqual(pushed.idempotency_key, "ident_123") + + def test_identify_without_options(self): + """Existing identify callers without options keep working unchanged.""" + with patch.object(self.schematic.event_buffer, "push") as mock_push: + self.schematic.identify(keys={"id": "user_id"}, name="User Name") + pushed = mock_push.call_args.args[0] + self.assertIsNone(pushed.idempotency_key) + def test_check_flag_with_no_cache(self): """Verify that when cache_providers is empty, every call hits the API.""" config = SchematicConfig( @@ -766,6 +849,40 @@ async def test_track(self): ) mock_push.assert_called_once() + async def test_track_with_options(self): + """All TrackOptions fields must plumb through async track() to the + CreateEventRequestBody.""" + import datetime as dt + sent_at = dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt.timezone.utc) + with patch.object(self.async_schematic.event_buffer, "push") as mock_push: + await self.async_schematic.track( + event="credit-consumed", + company={"id": "company_id"}, + options=TrackOptions( + idempotency_key="evt_abc123", + sent_at=sent_at, + trusted_client_clock=True, + backfill=False, + ), + ) + mock_push.assert_called_once() + pushed = mock_push.call_args.args[0] + assert pushed.idempotency_key == "evt_abc123" + assert pushed.sent_at == sent_at + assert pushed.trusted_client_clock is True + # backfill=False is explicitly set; it should land on the body. + assert pushed.backfill is False + + async def test_async_identify_with_options(self): + """IdentifyOptions must plumb through async identify().""" + with patch.object(self.async_schematic.event_buffer, "push") as mock_push: + await self.async_schematic.identify( + keys={"id": "user_id"}, + options=IdentifyOptions(idempotency_key="ident_async"), + ) + pushed = mock_push.call_args.args[0] + assert pushed.idempotency_key == "ident_async" + async def test_check_flag_with_no_cache(self): """Verify that when cache_providers is empty, every call hits the API.""" config = AsyncSchematicConfig( @@ -1050,6 +1167,79 @@ async def test_check_flag_datastream_local_evaluation_skips_api(self): finally: await client.event_buffer.stop() + async def test_check_flag_via_datastream_propagates_idempotency_key_to_flag_check_event(self): + """CheckFlagOptions.idempotency_key must land on the flag_check + CreateEventRequestBody pushed to the buffer when the SDK evaluates + the flag locally via DataStream — same dedupe contract as for + track/identify events.""" + from schematic.types import RulesengineCheckFlagResult + + config = AsyncSchematicConfig( + logger=MagicMock(), + httpx_client=MagicMock(spec=AsyncClient), + event_buffer_period=1, + use_datastream=True, + ) + client = AsyncSchematic("test_key", config) + try: + ds_result = RulesengineCheckFlagResult( + value=True, + flag_key="test_flag", + flag_id="flag-1", + reason="match", + rule_id="rule-1", + rule_type="override", + company_id="comp-1", + ) + mock_ds = MagicMock() + mock_ds.check_flag = AsyncMock(return_value=ds_result) + client._datastream_client = mock_ds + + with patch.object(client.event_buffer, "push") as mock_push: + await client.check_flag( + "test_flag", + company={"id": "comp-1"}, + options=CheckFlagOptions(idempotency_key="flag_check_evt_42"), + ) + + mock_push.assert_called_once() + pushed = mock_push.call_args.args[0] + assert pushed.event_type == "flag_check" + assert pushed.idempotency_key == "flag_check_evt_42" + finally: + await client.event_buffer.stop() + + async def test_check_flag_via_datastream_no_idempotency_key_leaves_field_none(self): + """Regression: when no idempotency_key is supplied, the flag_check + event still goes out cleanly with the field unset.""" + from schematic.types import RulesengineCheckFlagResult + + config = AsyncSchematicConfig( + logger=MagicMock(), + httpx_client=MagicMock(spec=AsyncClient), + event_buffer_period=1, + use_datastream=True, + ) + client = AsyncSchematic("test_key", config) + try: + ds_result = RulesengineCheckFlagResult( + value=True, flag_key="test_flag", flag_id="flag-1", + reason="match", rule_id="rule-1", rule_type="override", + company_id="comp-1", + ) + mock_ds = MagicMock() + mock_ds.check_flag = AsyncMock(return_value=ds_result) + client._datastream_client = mock_ds + + with patch.object(client.event_buffer, "push") as mock_push: + await client.check_flag("test_flag", company={"id": "comp-1"}) + + pushed = mock_push.call_args.args[0] + assert pushed.event_type == "flag_check" + assert pushed.idempotency_key is None + finally: + await client.event_buffer.stop() + async def test_check_flag_falls_back_to_api_when_flag_not_in_datastream_cache(self): """Spec checklist item 9 (DataStream): when the requested flag is not cached locally by the DataStream client, the wrapper must fall back diff --git a/tests/custom/test_event_capture.py b/tests/custom/test_event_capture.py new file mode 100644 index 0000000..1bd80ac --- /dev/null +++ b/tests/custom/test_event_capture.py @@ -0,0 +1,143 @@ +"""Tests for the event capture wire-format mapping. + +The capture service expects a specific JSON shape (api_key + type + optional +metadata) that's different from the Fern-generated CreateEventRequestBody. +These tests pin the mapping so optional fields like idempotency_key don't +silently get dropped on the way to the wire. +""" +from __future__ import annotations + +import datetime as dt +import json + +from schematic.event_capture import ( + _CaptureBatchPayload, + _CaptureEventPayload, + _serialize_batch, + _to_payload, +) +from schematic.types import CreateEventRequestBody, EventBodyTrack + + +def _make_event(**overrides) -> CreateEventRequestBody: + """Build a track event with arbitrary CreateEventRequestBody overrides.""" + return CreateEventRequestBody( + event_type="track", + body=EventBodyTrack( + event="some-event", + company={"id": "co_123"}, + ), + **overrides, + ) + + +class TestToPayloadMapping: + """_to_payload must copy every optional metadata field from + CreateEventRequestBody onto _CaptureEventPayload, so the capture service + receives values the SDK consumer set.""" + + def test_minimum_required_fields(self) -> None: + event = _make_event() + payload = _to_payload(event, api_key="sch_test") + + assert payload.api_key == "sch_test" + assert payload.type == "track" + assert payload.body is not None + assert payload.idempotency_key is None + assert payload.sent_at is None + assert payload.trusted_client_clock is None + assert payload.backfill is None + + def test_idempotency_key_mapped(self) -> None: + event = _make_event(idempotency_key="evt_abc123") + payload = _to_payload(event, api_key="sch_test") + assert payload.idempotency_key == "evt_abc123" + + def test_sent_at_mapped(self) -> None: + sent = dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt.timezone.utc) + event = _make_event(sent_at=sent) + payload = _to_payload(event, api_key="sch_test") + assert payload.sent_at == sent + + def test_trusted_client_clock_mapped(self) -> None: + event = _make_event(trusted_client_clock=True) + payload = _to_payload(event, api_key="sch_test") + assert payload.trusted_client_clock is True + + def test_backfill_mapped(self) -> None: + event = _make_event(backfill=True) + payload = _to_payload(event, api_key="sch_test") + assert payload.backfill is True + + def test_all_optional_fields_mapped_together(self) -> None: + sent = dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt.timezone.utc) + event = _make_event( + idempotency_key="evt_xyz", + sent_at=sent, + trusted_client_clock=True, + backfill=True, + ) + payload = _to_payload(event, api_key="sch_test") + assert payload.idempotency_key == "evt_xyz" + assert payload.sent_at == sent + assert payload.trusted_client_clock is True + assert payload.backfill is True + + +class TestSerializeBatch: + """The serialized JSON sent to the capture service must include the + optional metadata fields when set, and must exclude them (rather than + sending explicit nulls) when unset.""" + + def test_unset_optional_fields_excluded_from_wire(self) -> None: + event = _make_event() + body = _serialize_batch([event], api_key="sch_test") + data = json.loads(body) + + wire_event = data["events"][0] + assert wire_event["api_key"] == "sch_test" + assert wire_event["type"] == "track" + # Unset fields should not appear at all — exclude_none on the model + # ensures we don't send `"idempotency_key": null` and friends. + assert "idempotency_key" not in wire_event + assert "sent_at" not in wire_event + assert "trusted_client_clock" not in wire_event + assert "backfill" not in wire_event + + def test_set_optional_fields_appear_on_wire(self) -> None: + sent = dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt.timezone.utc) + event = _make_event( + idempotency_key="evt_xyz", + sent_at=sent, + trusted_client_clock=True, + backfill=False, + ) + body = _serialize_batch([event], api_key="sch_test") + data = json.loads(body) + + wire_event = data["events"][0] + assert wire_event["idempotency_key"] == "evt_xyz" + assert wire_event["trusted_client_clock"] is True + # backfill=False is explicitly set, so it must reach the wire even + # though the value is falsy. + assert wire_event["backfill"] is False + assert "sent_at" in wire_event + + +class TestCapturePayloadShape: + """Pin the wire field names so the capture service contract doesn't + silently drift if someone renames a Pydantic field.""" + + def test_uses_type_not_event_type(self) -> None: + """The capture service expects `type` (matching Go/Ruby/C# SDKs), + not `event_type` (which is the REST API name).""" + payload = _CaptureEventPayload(api_key="k", type="track") + dumped = payload.model_dump() + assert "type" in dumped + assert "event_type" not in dumped + + def test_batch_wrapper_uses_events_field(self) -> None: + batch = _CaptureBatchPayload(events=[]) + dumped = batch.model_dump() + assert "events" in dumped + assert dumped["events"] == []