Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_chat_model_cached,
get_web_search_tool,
extract_web_search_count,
extract_web_search_events,
convert_messages,
extract_usage,
)
Expand Down Expand Up @@ -445,6 +446,9 @@ def generate():
# blocks, citations, grounding metadata) can be counted for billing
# once the stream completes.
merged_chunk = None
# Tracks web-search blocks already surfaced to the client so the
# live "searching the web" status is emitted once per search.
seen_web_searches: set = set()

try:
if image_output_model:
Expand Down Expand Up @@ -517,6 +521,26 @@ def generate():
chunk if merged_chunk is None else merged_chunk + chunk
)

# --- Web-search status (UI only) ---
# When the provider performs a native web search mid-
# stream, emit a lightweight status frame so the client
# can show a "searching the web" indicator. The frame
# carries no content delta, so it is not part of the
# signed output hash and does not affect billing.
for ws_event in extract_web_search_events(
chunk, seen_web_searches
):
status_data: dict[str, Any] = {
"choices": [
{"delta": {}, "index": 0, "finish_reason": None}
],
"model": chat_request.model,
"web_search": {"status": "searching"},
}
if ws_event.get("query"):
status_data["web_search"]["query"] = ws_event["query"]
yield f"data: {json.dumps(status_data)}\n\n"

# --- Text content ---
if chunk.content:
if isinstance(chunk.content, str):
Expand Down
58 changes: 58 additions & 0 deletions tee_gateway/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,61 @@ def extract_web_search_count(message) -> int:
count += 1

return count


def _web_search_query_from_block(block: Dict[str, Any]) -> Optional[str]:
    """Pull the search query out of a web-search content block, if available.

    The query is looked for under ``input`` (where Anthropic
    ``server_tool_use`` blocks carry it) and then under ``action`` (where
    OpenAI ``web_search_call`` blocks carry it, when present). The first of
    those fields that is a dict holding a non-empty string ``query`` wins.

    While a response is still streaming the query may not have been fully
    accumulated; in that case None is returned and the client falls back to
    a generic "searching" message.
    """
    payloads = (block.get(field) for field in ("input", "action"))
    for payload in payloads:
        if not isinstance(payload, dict):
            # Missing, or not yet a parsed dict -- nothing to read here.
            continue
        candidate = payload.get("query")
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def extract_web_search_events(message, seen_ids: set) -> List[Dict[str, Any]]:
    """Report web-search blocks that appear for the first time in one chunk.

    Used to drive a live "searching the web" indicator in the UI. For every
    web-search block in ``message.content`` that has not been reported yet,
    one event dict ``{"query": <str|None>}`` is produced.

    ``seen_ids`` is updated in place with a ``(block type, id-or-index)`` key
    per reported block. Anthropic streams a block's input as incremental JSON
    deltas, so the same block shows up in several chunks under a stable
    id/index; the shared set ensures it is only reported once.

    This signal is for display only and is unrelated to
    ``extract_web_search_count``, which counts searches on the
    completed/accumulated message for billing.

    Returns an empty list when ``message`` is None or its content is not a
    list of blocks.
    """
    blocks = getattr(message, "content", None) if message is not None else None
    if not isinstance(blocks, list):
        return []

    new_events: List[Dict[str, Any]] = []
    for entry in blocks:
        if not isinstance(entry, dict):
            continue

        kind = entry.get("type")
        if kind == "server_tool_use":
            # Server-side tool calls only count when the tool is web_search.
            if entry.get("name") != "web_search":
                continue
        elif kind != "web_search_call":
            continue

        # Prefer the stable id; fall back to the streamed block index. This
        # yields one event per search rather than one per partial chunk.
        dedupe_key = (kind, entry.get("id") or entry.get("index"))
        if dedupe_key not in seen_ids:
            seen_ids.add(dedupe_key)
            new_events.append({"query": _web_search_query_from_block(entry)})

    return new_events
189 changes: 189 additions & 0 deletions tee_gateway/test/test_web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tee_gateway.llm_backend import (
get_web_search_tool,
extract_web_search_count,
extract_web_search_events,
)
from tee_gateway.pricing import SessionCost, compute_session_cost
from tee_gateway.controllers.chat_controller import create_chat_completion
Expand Down Expand Up @@ -130,6 +131,98 @@ def test_google_grounding_counts_as_one_request(self):
self.assertEqual(extract_web_search_count(msg), 1)


# ---------------------------------------------------------------------------
# llm_backend.extract_web_search_events (live UI status)
# ---------------------------------------------------------------------------


class TestExtractWebSearchEvents(unittest.TestCase):
    """Unit tests for ``extract_web_search_events`` (live UI status signal)."""

    @staticmethod
    def _openai_search(**fields):
        """Build an OpenAI-style ``web_search_call`` content block."""
        return {"type": "web_search_call", **fields}

    @staticmethod
    def _anthropic_search(**fields):
        """Build an Anthropic-style ``server_tool_use`` web-search block."""
        return {"type": "server_tool_use", "name": "web_search", **fields}

    def test_none_message(self):
        found = extract_web_search_events(None, set())
        self.assertEqual(found, [])

    def test_plain_text_chunk_has_no_events(self):
        found = extract_web_search_events(AIMessage(content="hi"), set())
        self.assertEqual(found, [])

    def test_string_content_chunk_has_no_events(self):
        # Streamed text deltas arrive as plain strings, not block lists.
        found = extract_web_search_events(AIMessage(content=""), set())
        self.assertEqual(found, [])

    def test_openai_web_search_call_emits_event(self):
        message = AIMessage(content=[self._openai_search(id="ws_1")])
        found = extract_web_search_events(message, set())
        self.assertEqual(len(found), 1)
        self.assertIsNone(found[0]["query"])

    def test_anthropic_server_tool_use_emits_event_with_query(self):
        block = self._anthropic_search(id="srv_1", input={"query": "latest news"})
        found = extract_web_search_events(AIMessage(content=[block]), set())
        self.assertEqual(found, [{"query": "latest news"}])

    def test_openai_query_from_action(self):
        block = self._openai_search(id="ws_1", action={"query": "weather today"})
        found = extract_web_search_events(AIMessage(content=[block]), set())
        self.assertEqual(found, [{"query": "weather today"}])

    def test_dedupes_block_across_chunks_by_id(self):
        already_seen: set = set()
        # Same search block reappears across chunks (Anthropic input deltas).
        opening = AIMessage(content=[self._anthropic_search(id="srv_1")])
        followup = AIMessage(
            content=[
                self._anthropic_search(id="srv_1", input={"query": "now complete"})
            ]
        )
        self.assertEqual(len(extract_web_search_events(opening, already_seen)), 1)
        # Already seen -> no duplicate event on the next chunk.
        self.assertEqual(extract_web_search_events(followup, already_seen), [])

    def test_distinct_searches_each_emit(self):
        already_seen: set = set()
        for search_id in ("ws_1", "ws_2"):
            message = AIMessage(content=[self._openai_search(id=search_id)])
            self.assertEqual(
                len(extract_web_search_events(message, already_seen)), 1
            )

    def test_dedupes_by_index_when_no_id(self):
        already_seen: set = set()
        message = AIMessage(content=[self._openai_search(index=0)])
        self.assertEqual(len(extract_web_search_events(message, already_seen)), 1)
        self.assertEqual(extract_web_search_events(message, already_seen), [])

    def test_non_search_blocks_ignored(self):
        unrelated_blocks = [
            {"type": "text", "text": "answer"},
            {"type": "web_search_tool_result", "content": []},
        ]
        found = extract_web_search_events(
            AIMessage(content=unrelated_blocks), set()
        )
        self.assertEqual(found, [])


# ---------------------------------------------------------------------------
# pricing.compute_session_cost with web search
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -274,5 +367,101 @@ def test_no_web_search_does_not_bind_or_bill_search(
self.assertEqual(mock_cost.call_args.kwargs["web_search_count"], 0)


class TestChatControllerWebSearchStreaming(unittest.TestCase):
    """Streaming-path tests for the web-search status frame in the SSE body."""

    @staticmethod
    def _collect_sse(response) -> str:
        """Concatenate every streamed part of *response*, decoding bytes as UTF-8."""
        return "".join(
            piece.decode("utf-8") if isinstance(piece, bytes) else piece
            for piece in response.response
        )

    def _stream_body(
        self,
        mock_connexion,
        mock_get_model,
        mock_get_tee_keys,
        mock_cost,
        payload,
        chunks,
    ) -> str:
        """Wire the patched collaborators, run the controller, return the SSE text.

        *payload* is served as the JSON request body and *chunks* is what the
        fake model's ``stream`` call returns.
        """
        mock_connexion.request.is_json = True
        mock_connexion.request.get_json.return_value = payload

        fake_model = Mock()
        fake_model.stream.return_value = chunks
        # bind_tools hands back the same fake model.
        fake_model.bind_tools.return_value = fake_model
        mock_get_model.return_value = fake_model
        mock_get_tee_keys.return_value = _mock_tee_keys()
        mock_cost.return_value = None

        return self._collect_sse(create_chat_completion(None))

    @patch("tee_gateway.controllers.chat_controller.compute_session_cost")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    @patch("tee_gateway.controllers.chat_controller.connexion")
    def test_streaming_emits_web_search_status_frame(
        self, mock_connexion, mock_get_model, mock_get_tee_keys, mock_cost
    ):
        from langchain_core.messages import AIMessageChunk

        def search_chunk(tool_input):
            # One streamed web-search block; only its "input" varies.
            return AIMessageChunk(
                content=[
                    {
                        "type": "server_tool_use",
                        "name": "web_search",
                        "id": "srv_1",
                        "input": tool_input,
                    }
                ]
            )

        # The provider streams a web-search block (twice, as input accumulates)
        # before the answer text — the status frame must be emitted only once.
        streamed = [
            search_chunk({}),
            search_chunk({"query": "latest news"}),
            AIMessageChunk(content="Here is the news."),
        ]
        request_payload = {
            "model": "claude-sonnet-4-5",
            "messages": [{"role": "user", "content": "latest news?"}],
            "web_search": True,
            "stream": True,
        }

        body = self._stream_body(
            mock_connexion,
            mock_get_model,
            mock_get_tee_keys,
            mock_cost,
            request_payload,
            streamed,
        )

        # Exactly one web-search status frame for the single (deduped) search.
        self.assertEqual(body.count('"web_search"'), 1)
        self.assertIn('"status": "searching"', body)
        # The answer text still streams as a normal content delta.
        self.assertIn("Here is the news.", body)
        self.assertIn("[DONE]", body)

    @patch("tee_gateway.controllers.chat_controller.compute_session_cost")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    @patch("tee_gateway.controllers.chat_controller.connexion")
    def test_streaming_without_web_search_emits_no_status_frame(
        self, mock_connexion, mock_get_model, mock_get_tee_keys, mock_cost
    ):
        from langchain_core.messages import AIMessageChunk

        request_payload = {
            "model": "gpt-4.1",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }

        body = self._stream_body(
            mock_connexion,
            mock_get_model,
            mock_get_tee_keys,
            mock_cost,
            request_payload,
            [AIMessageChunk(content="hi")],
        )

        self.assertNotIn('"web_search"', body)
        self.assertIn("hi", body)


# Script entry point: run this module's tests via unittest when executed directly.
if __name__ == "__main__":
unittest.main()
Loading