diff --git a/examples/chat/python/src/streaming/a2ui_partial_handler.py b/examples/chat/python/src/streaming/a2ui_partial_handler.py index f9f9c950..378b0ec5 100644 --- a/examples/chat/python/src/streaming/a2ui_partial_handler.py +++ b/examples/chat/python/src/streaming/a2ui_partial_handler.py @@ -1,10 +1,25 @@ # SPDX-License-Identifier: MIT """Streaming callback handler that sidebands a parent LLM's growing -tool_call.arguments as A2UI-partial custom events. Listens to LangChain's -on_llm_new_token events (the canonical streaming-token callback for -both chat models and legacy LLMs); per tool_call_id, concatenates argument -deltas and dispatches each cumulative state via adispatch_custom_event. -The frontend bridge (libs/chat partial-args-bridge) consumes these. +tool_call.arguments as A2UI-partial custom events. + +Hooks into LangChain's `on_llm_new_token` callback (the canonical +streaming-token callback fired by ChatOpenAI when streaming=True is +enabled). Two delta shapes are observed in production depending on +which OpenAI API the model uses: + + - Chat Completions API: `chunk.message.tool_call_chunks` carries the + classic `[{name, id, args, index}, ...]` deltas. + - Responses API (gpt-5 family): `chunk.message.content` carries a + list of content blocks; tool-call deltas appear as + `{type: 'function_call', name?: str, call_id?: str, arguments: str, + index: int}`. The first block for each call carries `name` and + `call_id`; subsequent blocks for the same `index` carry only the + `arguments` delta. + +The handler covers both shapes. For each delta belonging to our target +tool name, it concatenates arguments per `tool_call_id` and dispatches +an `a2ui-partial` custom event when the cumulative string grows. The +frontend bridge (libs/chat partial-args-bridge) consumes these. 
""" from __future__ import annotations @@ -17,20 +32,18 @@ class A2uiPartialHandler(AsyncCallbackHandler): """Track per-tool_call_id cumulative arguments; dispatch a2ui-partial - custom events when the cumulative string grows. - - Hooks into `on_llm_new_token` — the canonical streaming-token callback - fired by ChatOpenAI when `streaming=True` is enabled. For each chunk - we inspect the embedded `chunk.message.tool_call_chunks` list (only - populated when the LLM is mid-stream of a tool_call) and forward any - delta belonging to our target tool name to the frontend. - """ + custom events when the cumulative string grows.""" def __init__(self, tool_name: str = "render_a2ui_surface") -> None: super().__init__() self._tool_name = tool_name # tool_call_id -> cumulative args string self._buffers: dict[str, str] = {} + # Responses-API index -> tool_call_id mapping. The first content + # block for each tool call carries name + call_id; subsequent + # blocks for that `index` carry only the args delta, so we + # remember the mapping to attribute later deltas correctly. + self._index_to_call_id: dict[int, str] = {} async def on_llm_new_token( self, @@ -42,26 +55,77 @@ async def on_llm_new_token( tags: list[str] | None = None, **kwargs: Any, ) -> None: - # We only care about chat-model chunks that carry tool_call_chunks. if chunk is None: return message = getattr(chunk, "message", None) if message is None: return - tool_call_chunks = getattr(message, "tool_call_chunks", None) or [] - for tc in tool_call_chunks: - name = tc.get("name") or "" - call_id = tc.get("id") - delta = tc.get("args") or "" - if name != self._tool_name or not call_id: - continue - existing = self._buffers.get(call_id, "") - updated = existing + delta - if updated == existing: - # No growth — don't re-dispatch the same payload. 
- continue - self._buffers[call_id] = updated - await adispatch_custom_event( - "a2ui-partial", - {"tool_call_id": call_id, "args_so_far": updated}, - ) + + # Path 1: Chat Completions API — classic tool_call_chunks list. + for tc in getattr(message, "tool_call_chunks", None) or []: + await self._handle_classic_chunk(tc) + + # Path 2: Responses API — content blocks with type='function_call'. + content = getattr(message, "content", None) + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "function_call": + await self._handle_responses_block(block) + + async def _handle_classic_chunk(self, tc: dict) -> None: + """Chat Completions delta: {id, name?, args?, index}.""" + name = tc.get("name") or "" + call_id = tc.get("id") + delta = tc.get("args") or "" + if not call_id: + return + # The first chunk for a call carries the name; later chunks may omit + # it but should be attributed to the same call_id (which IS carried + # on every chunk in this shape). + if name and name != self._tool_name: + return + await self._dispatch_delta(call_id, delta) + + async def _handle_responses_block(self, block: dict) -> None: + """Responses API delta: {type, name?, call_id?, arguments, index}. + + The first block for a call carries `name` and `call_id`. Subsequent + blocks for the same `index` carry only the `arguments` delta — we + remember the index→call_id mapping to attribute them correctly. + """ + index = block.get("index") + name = block.get("name") or "" + call_id = block.get("call_id") + delta = block.get("arguments") or "" + + if call_id: + # First block for this call. Validate target tool name. + if name and name != self._tool_name: + # Not our target — drop any stale mapping for this index so + # later deltas for it are skipped. + if isinstance(index, int): + self._index_to_call_id.pop(index, None) + return + if isinstance(index, int): + self._index_to_call_id[index] = call_id + else: + # Subsequent block — look up call_id by index. 
+ if not isinstance(index, int): + return + call_id = self._index_to_call_id.get(index) + if not call_id: + # Either we never saw a first-block for this index, or it + # was for a non-target tool. Skip. + return + + await self._dispatch_delta(call_id, delta) + + async def _dispatch_delta(self, call_id: str, delta: str) -> None: + existing = self._buffers.get(call_id, "") + updated = existing + delta + if updated == existing: + return # No growth — suppress duplicate dispatch. + self._buffers[call_id] = updated + await adispatch_custom_event( + "a2ui-partial", + {"tool_call_id": call_id, "args_so_far": updated}, + ) diff --git a/examples/chat/python/tests/test_a2ui_partial_handler.py b/examples/chat/python/tests/test_a2ui_partial_handler.py index 5ed26b85..64d3858b 100644 --- a/examples/chat/python/tests/test_a2ui_partial_handler.py +++ b/examples/chat/python/tests/test_a2ui_partial_handler.py @@ -102,3 +102,58 @@ async def test_ignores_token_event_without_chunk_message(self): with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: await handler.on_llm_new_token("some token", chunk=None, run_id=uuid4()) mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_responses_api_function_call_content_blocks(self): + """gpt-5 / Responses API streams tool-call deltas as content blocks + of type='function_call' rather than tool_call_chunks. 
The first + block for each call carries name + call_id; subsequent blocks for + the same index carry only the args delta.""" + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + first = ChatGenerationChunk( + text="", + message=AIMessageChunk(content=[ + {"type": "function_call", "name": "render_a2ui_surface", + "call_id": "call_ABC", "id": "fc_1", + "arguments": "", "index": 1}, + ]), + ) + next1 = ChatGenerationChunk( + text="", + message=AIMessageChunk(content=[ + {"type": "function_call", "arguments": "{\"en", "index": 1}, + ]), + ) + next2 = ChatGenerationChunk( + text="", + message=AIMessageChunk(content=[ + {"type": "function_call", "arguments": "velopes", "index": 1}, + ]), + ) + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + await handler.on_llm_new_token("", chunk=first, run_id=uuid4()) + await handler.on_llm_new_token("", chunk=next1, run_id=uuid4()) + await handler.on_llm_new_token("", chunk=next2, run_id=uuid4()) + # First block has empty args — no dispatch. Two subsequent grow the buffer. + assert mock.await_count == 2 + call1 = mock.await_args_list[0].args[1] + call2 = mock.await_args_list[1].args[1] + assert call1["tool_call_id"] == "call_ABC" + assert call1["args_so_far"] == "{\"en" + assert call2["tool_call_id"] == "call_ABC" + assert call2["args_so_far"] == "{\"envelopes" + + @pytest.mark.asyncio + async def test_responses_api_ignores_non_target_tools(self): + """Responses-API content blocks for a different tool are ignored.""" + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + chunk = ChatGenerationChunk( + text="", + message=AIMessageChunk(content=[ + {"type": "function_call", "name": "search_documents", + "call_id": "call_X", "arguments": "{}", "index": 0}, + ]), + ) + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4()) + mock.assert_not_awaited()