From 26358c0168c1f3a3a42c07fe1e758ebe82f7f5ec Mon Sep 17 00:00:00 2001 From: Brian Love Date: Tue, 12 May 2026 15:00:22 -0700 Subject: [PATCH 1/2] feat(examples-chat): A2uiPartialHandler sidebands streaming envelopes Async callback handler tracking per-tool_call_id cumulative arguments from on_chat_model_stream events. Each growth in the cumulative string dispatches an a2ui-partial custom event carrying {tool_call_id, args_so_far}; the frontend partial-args-bridge consumes these and feeds envelopes into the A2UI surface store as they parse. --- .../src/streaming/a2ui_partial_handler.py | 50 ++++++++++++ .../python/tests/test_a2ui_partial_handler.py | 78 +++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 examples/chat/python/src/streaming/a2ui_partial_handler.py create mode 100644 examples/chat/python/tests/test_a2ui_partial_handler.py diff --git a/examples/chat/python/src/streaming/a2ui_partial_handler.py b/examples/chat/python/src/streaming/a2ui_partial_handler.py new file mode 100644 index 000000000..8f5a89ea6 --- /dev/null +++ b/examples/chat/python/src/streaming/a2ui_partial_handler.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: MIT +"""Streaming callback handler that sidebands a parent LLM's growing +tool_call.arguments as A2UI-partial custom events. Listens to LangChain's +on_chat_model_stream events; per tool_call_id, concatenates argument +deltas and dispatches each cumulative state via adispatch_custom_event. +The frontend bridge (libs/chat partial-args-bridge) consumes these. +""" +from __future__ import annotations + +from typing import Any + +from langchain_core.callbacks import AsyncCallbackHandler, adispatch_custom_event + + +class A2uiPartialHandler(AsyncCallbackHandler): + """Track per-tool_call_id cumulative arguments; dispatch a2ui-partial + custom events when the cumulative string grows.""" + + def __init__(self, tool_name: str = "render_a2ui_surface") -> None: + super().__init__() + self._tool_name = tool_name + # tool_call_id -> cumulative args string + self._buffers: dict[str, str] = {} + + async def on_chat_model_stream( + self, + chunk: Any, + *, + run_id: str | None = None, + **kwargs: Any, + ) -> None: + # `chunk` is an AIMessageChunk. Each chunk may carry multiple + # tool_call_chunks (e.g. interleaved across concurrent tool_calls). + tool_call_chunks = getattr(chunk, "tool_call_chunks", None) or [] + for tc in tool_call_chunks: + name = tc.get("name") or "" + call_id = tc.get("id") + delta = tc.get("args") or "" + if name != self._tool_name or not call_id: + continue + existing = self._buffers.get(call_id, "") + updated = existing + delta + if updated == existing: + # No growth — don't re-dispatch the same payload. + continue + self._buffers[call_id] = updated + await adispatch_custom_event( + "a2ui-partial", + {"tool_call_id": call_id, "args_so_far": updated}, + ) diff --git a/examples/chat/python/tests/test_a2ui_partial_handler.py b/examples/chat/python/tests/test_a2ui_partial_handler.py new file mode 100644 index 000000000..7eaa356f9 --- /dev/null +++ b/examples/chat/python/tests/test_a2ui_partial_handler.py @@ -0,0 +1,78 @@ +"""Tests for A2uiPartialHandler — drives canned on_chat_model_stream events.""" +from unittest.mock import patch, AsyncMock + +import pytest +from langchain_core.messages import AIMessageChunk + +from src.streaming.a2ui_partial_handler import A2uiPartialHandler + + +class TestA2uiPartialHandler: + @pytest.mark.asyncio + async def test_dispatches_event_when_chunk_grows_args(self): + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + chunk = AIMessageChunk(content="", tool_call_chunks=[ + {"id": "tc-1", "name": "render_a2ui_surface", "args": "{\"envelopes\":[", "index": 0}, + ]) + await handler.on_chat_model_stream(chunk, run_id="r1") + mock.assert_awaited_once_with("a2ui-partial", {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["}) + + @pytest.mark.asyncio + async def test_concatenates_args_across_chunks_same_tool_call_id(self): + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + await handler.on_chat_model_stream( + AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), + run_id="r1", + ) + await handler.on_chat_model_stream( + AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]), + run_id="r1", + ) + # Second dispatch carries the cumulative string. + assert mock.await_count == 2 + args = [call.args for call in mock.await_args_list] + assert args[0][1]["args_so_far"] == "{" + assert args[1][1]["args_so_far"] == "{\"x\":1}" + + @pytest.mark.asyncio + async def test_ignores_chunks_for_unrelated_tools(self): + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + await handler.on_chat_model_stream( + AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]), + run_id="r1", + ) + mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_dispatch_when_args_did_not_grow(self): + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + # First chunk grows the buffer; second chunk has empty args delta + # (a no-op chunk from the model) and must not re-dispatch. + await handler.on_chat_model_stream( + AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), + run_id="r1", + ) + await handler.on_chat_model_stream( + AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]), + run_id="r1", + ) + mock.assert_awaited_once() + + @pytest.mark.asyncio + async def test_per_tool_call_id_state_isolation(self): + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + await handler.on_chat_model_stream( + AIMessageChunk(content="", tool_call_chunks=[ + {"id": "tc-A", "name": "render_a2ui_surface", "args": "{", "index": 0}, + {"id": "tc-B", "name": "render_a2ui_surface", "args": "[", "index": 1}, + ]), + run_id="r1", + ) + assert mock.await_count == 2 + ids = {call.args[1]["tool_call_id"] for call in mock.await_args_list} + assert ids == {"tc-A", "tc-B"} From 65c9359d60ad88dcd128ea02d41593e00504bec2 Mon Sep 17 00:00:00 2001 From: Brian Love Date: Tue, 12 May 2026 15:00:28 -0700 Subject: [PATCH 2/2] feat(examples-chat): wire A2uiPartialHandler to generate node MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Attached only when gen_ui_mode='a2ui'. Sidebands the parent LLM's tool_call_chunks for render_a2ui_surface as a2ui-partial custom events. Together with the frontend partial-args bridge (claude/genui-streaming- frontend-bridge) and the envelope-tool refactor (claude/genui-streaming- envelope-tool), this realises the per-component fallback transition wired by PR #252 — surface mounts on first surfaceUpdate, components flip from fallback to real as dataModelUpdates stream in. --- examples/chat/python/src/graph.py | 10 +++- .../chat/python/tests/test_streaming_smoke.py | 52 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 examples/chat/python/tests/test_streaming_smoke.py diff --git a/examples/chat/python/src/graph.py b/examples/chat/python/src/graph.py index feb398cbd..78b6330c6 100644 --- a/examples/chat/python/src/graph.py +++ b/examples/chat/python/src/graph.py @@ -42,6 +42,7 @@ from langchain_core.tools import tool from langgraph_sdk import get_client +from src.streaming.a2ui_partial_handler import A2uiPartialHandler from src.streaming.envelope_tool import render_a2ui_surface from src.streaming.envelope_normalizer import normalize_envelope_args from src.schemas.a2ui_v1 import A2UI_V1_SCHEMA_PROMPT @@ -410,7 +411,14 @@ async def generate(state: State, config: RunnableConfig) -> dict: "entries. This lets the client mount the surface as early as possible." ) messages = [SystemMessage(content=system)] + state["messages"] - response = await llm.ainvoke(messages) + # When in a2ui mode, attach the partial-envelope sideband handler so + # the parent LLM's tool_call_chunks for render_a2ui_surface are + # dispatched as a2ui-partial custom events (consumed by the frontend + # partial-args bridge). + callbacks = [] + if gen_ui_mode == "a2ui": + callbacks.append(A2uiPartialHandler(tool_name="render_a2ui_surface")) + response = await llm.ainvoke(messages, config={"callbacks": callbacks}) return {"messages": [response]} diff --git a/examples/chat/python/tests/test_streaming_smoke.py b/examples/chat/python/tests/test_streaming_smoke.py new file mode 100644 index 000000000..8dee956c0 --- /dev/null +++ b/examples/chat/python/tests/test_streaming_smoke.py @@ -0,0 +1,52 @@ +"""Integration smoke: the generate node, when invoked with a canned chat-model +stream, dispatches a2ui-partial events and the final state has the +single-bubble shape from PR #255.""" +import json +from unittest.mock import patch, AsyncMock + +import pytest +from langchain_core.messages import AIMessageChunk + +from src.streaming.a2ui_partial_handler import A2uiPartialHandler + + +def _make_canned_stream() -> list[AIMessageChunk]: + """Five chunks of growing args for one tool_call to render_a2ui_surface.""" + return [ + AIMessageChunk(content="", tool_call_chunks=[{ + "id": "tc-1", "name": "render_a2ui_surface", "index": 0, + "args": '{"envelopes":[', + }]), + AIMessageChunk(content="", tool_call_chunks=[{ + "id": "tc-1", "name": "render_a2ui_surface", "index": 0, + "args": '{"surfaceUpdate":{"surfaceId":"s","components":[{"id":"root","type":"text","props":{}}]}},', + }]), + AIMessageChunk(content="", tool_call_chunks=[{ + "id": "tc-1", "name": "render_a2ui_surface", "index": 0, + "args": '{"beginRendering":{"surfaceId":"s","root":"root"}},', + }]), + AIMessageChunk(content="", tool_call_chunks=[{ + "id": "tc-1", "name": "render_a2ui_surface", "index": 0, + "args": '{"dataModelUpdate":{"surfaceId":"s","contents":[{"key":"text","valueString":"hi"}]}}', + }]), + AIMessageChunk(content="", tool_call_chunks=[{ + "id": "tc-1", "name": "render_a2ui_surface", "index": 0, + "args": "]}", + }]), + ] + + +@pytest.mark.asyncio +async def test_handler_dispatches_per_chunk(): + """At least 3 a2ui-partial events fire as the canned stream advances.""" + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + for chunk in _make_canned_stream(): + await handler.on_chat_model_stream(chunk, run_id="r1") + assert mock.await_count >= 3 + # Last cumulative string is the full envelope JSON. + last = mock.await_args_list[-1].args[1] + assert last["tool_call_id"] == "tc-1" + body = json.loads(last["args_so_far"]) + assert "envelopes" in body + assert len(body["envelopes"]) == 3