diff --git a/agent/core/prompt_caching.py b/agent/core/prompt_caching.py index 04304bcd..38c58068 100644 --- a/agent/core/prompt_caching.py +++ b/agent/core/prompt_caching.py @@ -1,8 +1,11 @@ """Prompt-cache helpers for HF Router FAL requests. The HF Router/OpenRouter path uses provider-native prompt caching. Anthropic -models need explicit JSON ``cache_control`` content blocks; OpenAI models cache -eligible prefixes automatically and accept routing/retention hints in the body. +models keep explicit JSON ``cache_control`` content blocks for compatibility, +and also need the top-level ``cache_control`` hint on the OpenAI-compatible HF +Router path; the explicit markers alone are accepted there but do not produce +cache writes. OpenAI models cache eligible prefixes automatically and accept +routing/retention hints in the body. Headers like ``X-OpenRouter-Cache`` control response caching, not prompt caching through this route. """ @@ -67,6 +70,9 @@ def with_prompt_cache_params( if _is_openai_gpt55(llm_params): updates["prompt_cache_key"] = stable_session_id + if _uses_explicit_cache_control(llm_params): + updates["cache_control"] = dict(_CACHE_CONTROL) + if _is_openai_gpt55(llm_params): updates["prompt_cache_retention"] = "24h" diff --git a/agent/core/telemetry.py b/agent/core/telemetry.py index f71078b7..43b359c5 100644 --- a/agent/core/telemetry.py +++ b/agent/core/telemetry.py @@ -52,14 +52,18 @@ def _g(name, default=0): cache_read = _g("cache_read_input_tokens") cache_creation = _g("cache_creation_input_tokens") - - if not cache_read: - details = _g("prompt_tokens_details", None) - if details is not None: - if isinstance(details, dict): - cache_read = details.get("cached_tokens", 0) or 0 - else: - cache_read = getattr(details, "cached_tokens", 0) or 0 + details = _g("prompt_tokens_details", None) + + if not cache_read and details is not None: + if isinstance(details, dict): + cache_read = details.get("cached_tokens", 0) or 0 + else: + cache_read = getattr(details, "cached_tokens", 0) or 0 + if not cache_creation and details is not None: + if isinstance(details, dict): + cache_creation = details.get("cache_write_tokens", 0) or 0 + else: + cache_creation = getattr(details, "cache_write_tokens", 0) or 0 return { "prompt_tokens": int(prompt), diff --git a/frontend/src/components/UsageMeter.tsx b/frontend/src/components/UsageMeter.tsx index 55f054cc..89a0d2f4 100644 --- a/frontend/src/components/UsageMeter.tsx +++ b/frontend/src/components/UsageMeter.tsx @@ -230,7 +230,7 @@ export default function UsageMeter() { Usage - Billing window resets when you switch back to a task. + Estimated from HF account usage per session. {error ? ( diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 95ec8ebf..424b8536 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -182,10 +182,19 @@ def test_prompt_cache_params_add_session_id_for_fal_router_model(): cached_params = with_prompt_cache_params(llm_params, session_id="session-1") assert cached_params is not llm_params - assert cached_params["extra_body"] == {"session_id": "session-1"} + assert cached_params["extra_body"] == { + "session_id": "session-1", + "cache_control": {"type": "ephemeral"}, + } assert "extra_body" not in llm_params +def test_prompt_cache_params_adds_anthropic_cache_control_without_session_id(): + cached_params = with_prompt_cache_params(_anthropic_fal_params()) + + assert cached_params["extra_body"] == {"cache_control": {"type": "ephemeral"}} + + def test_prompt_cache_params_merges_gpt55_cache_hints(): llm_params = { **_gpt55_fal_params(), diff --git a/tests/unit/test_telemetry_usage.py b/tests/unit/test_telemetry_usage.py index 4a1f3db9..dfb27bbb 100644 --- a/tests/unit/test_telemetry_usage.py +++ b/tests/unit/test_telemetry_usage.py @@ -13,6 +13,25 @@ async def send_event(self, event): self.events.append(event) +def test_extract_usage_reads_hf_router_cache_write_tokens(): + response = SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=100, + completion_tokens=10, + total_tokens=110, + prompt_tokens_details=SimpleNamespace( + cached_tokens=80, + cache_write_tokens=20, + ), + ) + ) + + usage = telemetry.extract_usage(response) + + assert usage["cache_read_tokens"] == 80 + assert usage["cache_creation_tokens"] == 20 + + @pytest.mark.asyncio async def test_record_hf_job_complete_emits_runtime_cost(monkeypatch): async def fake_catalog():