diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index f96ce85..d023ee9 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -29,6 +29,7 @@ get_chat_model_cached, get_web_search_tool, extract_web_search_count, + extract_web_search_events, convert_messages, extract_usage, ) @@ -445,6 +446,9 @@ def generate(): # blocks, citations, grounding metadata) can be counted for billing # once the stream completes. merged_chunk = None + # Tracks web-search blocks already surfaced to the client so the + # live "searching the web" status is emitted once per search. + seen_web_searches: set = set() try: if image_output_model: @@ -517,6 +521,26 @@ def generate(): chunk if merged_chunk is None else merged_chunk + chunk ) + # --- Web-search status (UI only) --- + # When the provider performs a native web search mid- + # stream, emit a lightweight status frame so the client + # can show a "searching the web" indicator. The frame + # carries no content delta, so it is not part of the + # signed output hash and does not affect billing. + for ws_event in extract_web_search_events( + chunk, seen_web_searches + ): + status_data: dict[str, Any] = { + "choices": [ + {"delta": {}, "index": 0, "finish_reason": None} + ], + "model": chat_request.model, + "web_search": {"status": "searching"}, + } + if ws_event.get("query"): + status_data["web_search"]["query"] = ws_event["query"] + yield f"data: {json.dumps(status_data)}\n\n" + # --- Text content --- if chunk.content: if isinstance(chunk.content, str): diff --git a/tee_gateway/llm_backend.py b/tee_gateway/llm_backend.py index 74c4721..29aa8bb 100644 --- a/tee_gateway/llm_backend.py +++ b/tee_gateway/llm_backend.py @@ -433,3 +433,61 @@ def extract_web_search_count(message) -> int: count += 1 return count + + +def _web_search_query_from_block(block: Dict[str, Any]) -> Optional[str]: + """Best-effort extraction of the search query from a web-search content block. + + Anthropic ``server_tool_use`` blocks carry the query under ``input`` and + OpenAI ``web_search_call`` blocks under ``action`` (when present). During + streaming the query may not be fully accumulated yet, in which case we + return None and the client shows a generic "searching" message. + """ + for key in ("input", "action"): + value = block.get(key) + if isinstance(value, dict): + query = value.get("query") + if isinstance(query, str) and query: + return query + return None + + +def extract_web_search_events(message, seen_ids: set) -> List[Dict[str, Any]]: + """Detect web-search activity in a single streamed chunk, for live UI status. + + Returns one event dict ``{"query": }`` per newly-seen web-search + block so the client can surface a "searching the web" indicator while the + model browses. ``seen_ids`` is mutated to dedupe blocks that span multiple + chunks (Anthropic streams a block's input as incremental JSON deltas, so the + same block reappears across chunks under a stable id/index). + + This is a UI-only signal: it is independent of ``extract_web_search_count``, + which counts the completed/accumulated message for billing. + """ + if message is None: + return [] + + events: List[Dict[str, Any]] = [] + content = getattr(message, "content", None) + if not isinstance(content, list): + return events + + for block in content: + if not isinstance(block, dict): + continue + btype = block.get("type") + is_web_search = btype == "web_search_call" or ( + btype == "server_tool_use" and block.get("name") == "web_search" + ) + if not is_web_search: + continue + + # Dedupe by stable id (falling back to the streamed block index) so we + # emit one event per actual search rather than one per partial chunk. + block_key = (btype, block.get("id") or block.get("index")) + if block_key in seen_ids: + continue + seen_ids.add(block_key) + events.append({"query": _web_search_query_from_block(block)}) + + return events diff --git a/tee_gateway/test/test_web_search.py b/tee_gateway/test/test_web_search.py index 277dcd3..c24580e 100644 --- a/tee_gateway/test/test_web_search.py +++ b/tee_gateway/test/test_web_search.py @@ -24,6 +24,7 @@ from tee_gateway.llm_backend import ( get_web_search_tool, extract_web_search_count, + extract_web_search_events, ) from tee_gateway.pricing import SessionCost, compute_session_cost from tee_gateway.controllers.chat_controller import create_chat_completion @@ -130,6 +131,98 @@ def test_google_grounding_counts_as_one_request(self): self.assertEqual(extract_web_search_count(msg), 1) +# --------------------------------------------------------------------------- +# llm_backend.extract_web_search_events (live UI status) +# --------------------------------------------------------------------------- + + +class TestExtractWebSearchEvents(unittest.TestCase): + def test_none_message(self): + self.assertEqual(extract_web_search_events(None, set()), []) + + def test_plain_text_chunk_has_no_events(self): + self.assertEqual(extract_web_search_events(AIMessage(content="hi"), set()), []) + + def test_string_content_chunk_has_no_events(self): + # Streamed text deltas arrive as plain strings, not block lists. + self.assertEqual(extract_web_search_events(AIMessage(content=""), set()), []) + + def test_openai_web_search_call_emits_event(self): + msg = AIMessage(content=[{"type": "web_search_call", "id": "ws_1"}]) + events = extract_web_search_events(msg, set()) + self.assertEqual(len(events), 1) + self.assertIsNone(events[0]["query"]) + + def test_anthropic_server_tool_use_emits_event_with_query(self): + msg = AIMessage( + content=[ + { + "type": "server_tool_use", + "name": "web_search", + "id": "srv_1", + "input": {"query": "latest news"}, + } + ] + ) + events = extract_web_search_events(msg, set()) + self.assertEqual(events, [{"query": "latest news"}]) + + def test_openai_query_from_action(self): + msg = AIMessage( + content=[ + { + "type": "web_search_call", + "id": "ws_1", + "action": {"query": "weather today"}, + } + ] + ) + events = extract_web_search_events(msg, set()) + self.assertEqual(events, [{"query": "weather today"}]) + + def test_dedupes_block_across_chunks_by_id(self): + seen: set = set() + # Same search block reappears across chunks (Anthropic input deltas). + first = AIMessage( + content=[{"type": "server_tool_use", "name": "web_search", "id": "srv_1"}] + ) + second = AIMessage( + content=[ + { + "type": "server_tool_use", + "name": "web_search", + "id": "srv_1", + "input": {"query": "now complete"}, + } + ] + ) + self.assertEqual(len(extract_web_search_events(first, seen)), 1) + # Already seen -> no duplicate event on the next chunk. + self.assertEqual(extract_web_search_events(second, seen), []) + + def test_distinct_searches_each_emit(self): + seen: set = set() + first = AIMessage(content=[{"type": "web_search_call", "id": "ws_1"}]) + second = AIMessage(content=[{"type": "web_search_call", "id": "ws_2"}]) + self.assertEqual(len(extract_web_search_events(first, seen)), 1) + self.assertEqual(len(extract_web_search_events(second, seen)), 1) + + def test_dedupes_by_index_when_no_id(self): + seen: set = set() + chunk = AIMessage(content=[{"type": "web_search_call", "index": 0}]) + self.assertEqual(len(extract_web_search_events(chunk, seen)), 1) + self.assertEqual(extract_web_search_events(chunk, seen), []) + + def test_non_search_blocks_ignored(self): + msg = AIMessage( + content=[ + {"type": "text", "text": "answer"}, + {"type": "web_search_tool_result", "content": []}, + ] + ) + self.assertEqual(extract_web_search_events(msg, set()), []) + + # --------------------------------------------------------------------------- # pricing.compute_session_cost with web search # --------------------------------------------------------------------------- @@ -274,5 +367,101 @@ def test_no_web_search_does_not_bind_or_bill_search( self.assertEqual(mock_cost.call_args.kwargs["web_search_count"], 0) +class TestChatControllerWebSearchStreaming(unittest.TestCase): + @staticmethod + def _collect_sse(response) -> str: + parts = [] + for part in response.response: + parts.append(part.decode("utf-8") if isinstance(part, bytes) else part) + return "".join(parts) + + @patch("tee_gateway.controllers.chat_controller.compute_session_cost") + @patch("tee_gateway.controllers.chat_controller.get_tee_keys") + @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached") + @patch("tee_gateway.controllers.chat_controller.connexion") + def test_streaming_emits_web_search_status_frame( + self, mock_connexion, mock_get_model, mock_get_tee_keys, mock_cost + ): + from langchain_core.messages import AIMessageChunk + + mock_connexion.request.is_json = True + mock_connexion.request.get_json.return_value = { + "model": "claude-sonnet-4-5", + "messages": [{"role": "user", "content": "latest news?"}], + "web_search": True, + "stream": True, + } + + # The provider streams a web-search block (twice, as input accumulates) + # before the answer text — the status frame must be emitted only once. + chunks = [ + AIMessageChunk( + content=[ + { + "type": "server_tool_use", + "name": "web_search", + "id": "srv_1", + "input": {}, + } + ] + ), + AIMessageChunk( + content=[ + { + "type": "server_tool_use", + "name": "web_search", + "id": "srv_1", + "input": {"query": "latest news"}, + } + ] + ), + AIMessageChunk(content="Here is the news."), + ] + model = Mock() + model.stream.return_value = chunks + model.bind_tools.return_value = model + mock_get_model.return_value = model + mock_get_tee_keys.return_value = _mock_tee_keys() + mock_cost.return_value = None + + response = create_chat_completion(None) + body = self._collect_sse(response) + + # Exactly one web-search status frame for the single (deduped) search. + self.assertEqual(body.count('"web_search"'), 1) + self.assertIn('"status": "searching"', body) + # The answer text still streams as a normal content delta. + self.assertIn("Here is the news.", body) + self.assertIn("[DONE]", body) + + @patch("tee_gateway.controllers.chat_controller.compute_session_cost") + @patch("tee_gateway.controllers.chat_controller.get_tee_keys") + @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached") + @patch("tee_gateway.controllers.chat_controller.connexion") + def test_streaming_without_web_search_emits_no_status_frame( + self, mock_connexion, mock_get_model, mock_get_tee_keys, mock_cost + ): + from langchain_core.messages import AIMessageChunk + + mock_connexion.request.is_json = True + mock_connexion.request.get_json.return_value = { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + } + model = Mock() + model.stream.return_value = [AIMessageChunk(content="hi")] + model.bind_tools.return_value = model + mock_get_model.return_value = model + mock_get_tee_keys.return_value = _mock_tee_keys() + mock_cost.return_value = None + + response = create_chat_completion(None) + body = self._collect_sse(response) + + self.assertNotIn('"web_search"', body) + self.assertIn("hi", body) + + if __name__ == "__main__": unittest.main()