diff --git a/examples/chat/python/src/streaming/a2ui_partial_handler.py b/examples/chat/python/src/streaming/a2ui_partial_handler.py index 8f5a89ea..f9f9c950 100644 --- a/examples/chat/python/src/streaming/a2ui_partial_handler.py +++ b/examples/chat/python/src/streaming/a2ui_partial_handler.py @@ -1,20 +1,30 @@ # SPDX-License-Identifier: MIT """Streaming callback handler that sidebands a parent LLM's growing tool_call.arguments as A2UI-partial custom events. Listens to LangChain's -on_chat_model_stream events; per tool_call_id, concatenates argument +on_llm_new_token events (the canonical streaming-token callback for +both chat models and legacy LLMs); per tool_call_id, concatenates argument deltas and dispatches each cumulative state via adispatch_custom_event. The frontend bridge (libs/chat partial-args-bridge) consumes these. """ from __future__ import annotations from typing import Any +from uuid import UUID from langchain_core.callbacks import AsyncCallbackHandler, adispatch_custom_event +from langchain_core.outputs import ChatGenerationChunk, GenerationChunk class A2uiPartialHandler(AsyncCallbackHandler): """Track per-tool_call_id cumulative arguments; dispatch a2ui-partial - custom events when the cumulative string grows.""" + custom events when the cumulative string grows. + + Hooks into `on_llm_new_token` — the canonical streaming-token callback + fired by ChatOpenAI when `streaming=True` is enabled. For each chunk + we inspect the embedded `chunk.message.tool_call_chunks` list (only + populated when the LLM is mid-stream of a tool_call) and forward any + delta belonging to our target tool name to the frontend. + """ def __init__(self, tool_name: str = "render_a2ui_surface") -> None: super().__init__() @@ -22,16 +32,23 @@ def __init__(self, tool_name: str = "render_a2ui_surface") -> None: # tool_call_id -> cumulative args string self._buffers: dict[str, str] = {} - async def on_chat_model_stream( + async def on_llm_new_token( self, - chunk: Any, + token: str, *, - run_id: str | None = None, + chunk: ChatGenerationChunk | GenerationChunk | None = None, + run_id: UUID | None = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: - # `chunk` is an AIMessageChunk. Each chunk may carry multiple - # tool_call_chunks (e.g. interleaved across concurrent tool_calls). - tool_call_chunks = getattr(chunk, "tool_call_chunks", None) or [] + # We only care about chat-model chunks that carry tool_call_chunks. 
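+        # `chunk` may be None (the base signature defaults it to None), a bare
+        # GenerationChunk from a completion-style LLM (no `.message`
+        # attribute), or a ChatGenerationChunk whose AIMessageChunk carries
+        # the tool_call_chunks deltas we forward below.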
+ if chunk is None: + return + message = getattr(chunk, "message", None) + if message is None: + return + tool_call_chunks = getattr(message, "tool_call_chunks", None) or [] for tc in tool_call_chunks: name = tc.get("name") or "" call_id = tc.get("id") diff --git a/examples/chat/python/tests/test_a2ui_partial_handler.py b/examples/chat/python/tests/test_a2ui_partial_handler.py index 7eaa356f..5ed26b85 100644 --- a/examples/chat/python/tests/test_a2ui_partial_handler.py +++ b/examples/chat/python/tests/test_a2ui_partial_handler.py @@ -1,34 +1,47 @@ -"""Tests for A2uiPartialHandler — drives canned on_chat_model_stream events.""" +"""Tests for A2uiPartialHandler — drives canned on_llm_new_token events.""" from unittest.mock import patch, AsyncMock +from uuid import uuid4 import pytest from langchain_core.messages import AIMessageChunk +from langchain_core.outputs import ChatGenerationChunk from src.streaming.a2ui_partial_handler import A2uiPartialHandler +def _make_chunk(tool_call_chunks: list[dict]) -> ChatGenerationChunk: + """Wrap an AIMessageChunk in a ChatGenerationChunk the way the real + LangChain streaming callback path does.""" + return ChatGenerationChunk( + text="", + message=AIMessageChunk(content="", tool_call_chunks=tool_call_chunks), + ) + + class TestA2uiPartialHandler: @pytest.mark.asyncio async def test_dispatches_event_when_chunk_grows_args(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - chunk = AIMessageChunk(content="", tool_call_chunks=[ + chunk = _make_chunk([ {"id": "tc-1", "name": "render_a2ui_surface", "args": "{\"envelopes\":[", "index": 0}, ]) - await handler.on_chat_model_stream(chunk, run_id="r1") + await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4()) mock.assert_awaited_once_with("a2ui-partial", {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["}) @pytest.mark.asyncio async def test_concatenates_args_across_chunks_same_tool_call_id(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - await handler.on_chat_model_stream( - AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), - run_id="r1", + await handler.on_llm_new_token( + "", + chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), + run_id=uuid4(), ) - await handler.on_chat_model_stream( - AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]), - run_id="r1", + await handler.on_llm_new_token( + "", + chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]), + run_id=uuid4(), ) # Second dispatch carries the cumulative string. 
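+            # Concretely: '{' from the first chunk plus '"x":1}' from the
+            # second accumulates to '{"x":1}'.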
assert mock.await_count == 2 @@ -40,9 +53,10 @@ async def test_concatenates_args_across_chunks_same_tool_call_id(self): async def test_ignores_chunks_for_unrelated_tools(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - await handler.on_chat_model_stream( - AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]), - run_id="r1", + await handler.on_llm_new_token( + "", + chunk=_make_chunk([{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]), + run_id=uuid4(), ) mock.assert_not_awaited() @@ -52,13 +66,15 @@ async def test_no_dispatch_when_args_did_not_grow(self): with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: # First chunk grows the buffer; second chunk has empty args delta # (a no-op chunk from the model) and must not re-dispatch. - await handler.on_chat_model_stream( - AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), - run_id="r1", + await handler.on_llm_new_token( + "", + chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), + run_id=uuid4(), ) - await handler.on_chat_model_stream( - AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]), - run_id="r1", + await handler.on_llm_new_token( + "", + chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]), + run_id=uuid4(), ) mock.assert_awaited_once() @@ -66,13 +82,23 @@ async def test_no_dispatch_when_args_did_not_grow(self): async def test_per_tool_call_id_state_isolation(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - await handler.on_chat_model_stream( - AIMessageChunk(content="", tool_call_chunks=[ + await handler.on_llm_new_token( + "", + chunk=_make_chunk([ {"id": "tc-A", "name": "render_a2ui_surface", "args": "{", "index": 0}, {"id": "tc-B", "name": "render_a2ui_surface", "args": "[", "index": 1}, ]), - run_id="r1", + run_id=uuid4(), ) assert mock.await_count == 2 ids = {call.args[1]["tool_call_id"] for call in mock.await_args_list} assert ids == {"tc-A", "tc-B"} + + @pytest.mark.asyncio + async def test_ignores_token_event_without_chunk_message(self): + """Some emitters of on_llm_new_token may pass chunk=None (legacy LLM + path). 
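+        The base AsyncCallbackHandler signature defaults `chunk` to None,
+        so tolerating it is part of the callback contract.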
+        Handler must silently skip: no crash, no dispatch."""
+        handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
+        with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
+            await handler.on_llm_new_token("some token", chunk=None, run_id=uuid4())
+        mock.assert_not_awaited()
diff --git a/examples/chat/python/tests/test_streaming_smoke.py b/examples/chat/python/tests/test_streaming_smoke.py
index 8dee956c..04e3b979 100644
--- a/examples/chat/python/tests/test_streaming_smoke.py
+++ b/examples/chat/python/tests/test_streaming_smoke.py
@@ -1,38 +1,37 @@
-"""Integration smoke: the generate node, when invoked with a canned chat-model
-stream, dispatches a2ui-partial events and the final state has the
-single-bubble shape from PR #255."""
+"""Integration smoke: when fed a canned chat-model token stream, the
+A2uiPartialHandler dispatches an a2ui-partial custom event per growing
+chunk (the smoke assertion requires at least 3), each carrying the
+cumulative tool_call.args string."""
 
 import json
 from unittest.mock import patch, AsyncMock
+from uuid import uuid4
 
 import pytest
 from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 
 from src.streaming.a2ui_partial_handler import A2uiPartialHandler
 
 
-def _make_canned_stream() -> list[AIMessageChunk]:
-    """Five chunks of growing args for one tool_call to render_a2ui_surface."""
+def _make_canned_stream() -> list[ChatGenerationChunk]:
+    """Five chunks of growing args for one tool_call to render_a2ui_surface,
+    wrapped in ChatGenerationChunk the way LangChain passes them to the
+    on_llm_new_token callback."""
+    deltas = [
+        '{"envelopes":[',
+        '{"surfaceUpdate":{"surfaceId":"s","components":[{"id":"root","type":"text","props":{}}]}},',
+        '{"beginRendering":{"surfaceId":"s","root":"root"}},',
+        '{"dataModelUpdate":{"surfaceId":"s","contents":[{"key":"text","valueString":"hi"}]}}',
+        "]}",
+    ]
     return [
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"envelopes":[',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"surfaceUpdate":{"surfaceId":"s","components":[{"id":"root","type":"text","props":{}}]}},',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"beginRendering":{"surfaceId":"s","root":"root"}},',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"dataModelUpdate":{"surfaceId":"s","contents":[{"key":"text","valueString":"hi"}]}}',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": "]}",
-        }]),
+        ChatGenerationChunk(
+            text="",
+            message=AIMessageChunk(content="", tool_call_chunks=[{
+                "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
+                "args": delta,
+            }]),
+        )
+        for delta in deltas
     ]
 
 
@@ -42,7 +41,7 @@ async def test_handler_dispatches_per_chunk():
     handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
    with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
         for chunk in _make_canned_stream():
-            await handler.on_chat_model_stream(chunk, run_id="r1")
+            await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
         assert mock.await_count >= 3
         # Last cumulative string is the full envelope JSON.
         last = mock.await_args_list[-1].args[1]
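
Reviewer note (not part of the patch): a minimal end-to-end sketch of how this handler is meant to be attached, for anyone trying the change locally. Everything below is illustrative: it assumes langchain-openai is installed, and the model name, prompt, and the stand-in render_a2ui_surface tool schema are assumptions, not code from this repo. In the real app the libs/chat partial-args-bridge consumes the a2ui-partial events; here astream_events is used to observe them, since adispatch_custom_event needs a surrounding run context that can surface custom events.

    # Sketch only: model name, prompt, and the stand-in tool schema are
    # assumptions for illustration, not code from this repo.
    import asyncio

    from langchain_core.tools import tool
    from langchain_openai import ChatOpenAI

    from src.streaming.a2ui_partial_handler import A2uiPartialHandler


    @tool
    def render_a2ui_surface(envelopes: list) -> str:
        """Illustrative stand-in for the real A2UI surface-rendering tool."""
        return "ok"


    async def main() -> None:
        # streaming=True makes the chat model fire on_llm_new_token per chunk,
        # which is the callback A2uiPartialHandler hooks into.
        model = ChatOpenAI(model="gpt-4o-mini", streaming=True).bind_tools(
            [render_a2ui_surface]
        )
        handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
        # astream_events surfaces adispatch_custom_event payloads as
        # on_custom_event entries alongside the model's own stream events.
        async for event in model.astream_events(
            "Render a short greeting surface.",
            config={"callbacks": [handler]},
            version="v2",
        ):
            if event["event"] == "on_custom_event" and event["name"] == "a2ui-partial":
                print(event["data"]["args_so_far"][-60:])


    asyncio.run(main())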