diff --git a/examples/chat/python/src/streaming/a2ui_partial_handler.py b/examples/chat/python/src/streaming/a2ui_partial_handler.py index 378b0ec53..7b08a606f 100644 --- a/examples/chat/python/src/streaming/a2ui_partial_handler.py +++ b/examples/chat/python/src/streaming/a2ui_partial_handler.py @@ -17,22 +17,33 @@ `arguments` delta. The handler covers both shapes. For each delta belonging to our target -tool name, it concatenates arguments per `tool_call_id` and dispatches -an `a2ui-partial` custom event when the cumulative string grows. The -frontend bridge (libs/chat partial-args-bridge) consumes these. +tool name, it concatenates arguments per `tool_call_id` and writes an +`a2ui-partial` event into LangGraph's custom stream via the writer +returned by `langgraph.config.get_stream_writer()`. The writer's +payload reaches the frontend under `stream_mode='custom'` as a +`{type: 'custom', data: ...}` SSE event; the partial-args bridge +consumes it. + +LangGraph's `custom` stream mode is decoupled from langchain_core's +`adispatch_custom_event` — they're different layers. `adispatch_custom_event` +emits callback events visible only via `stream_mode='events'`; the +LangGraph writer is the canonical way to surface custom data on the +SDK's `custom` channel. """ from __future__ import annotations from typing import Any from uuid import UUID -from langchain_core.callbacks import AsyncCallbackHandler, adispatch_custom_event +from langchain_core.callbacks import AsyncCallbackHandler from langchain_core.outputs import ChatGenerationChunk, GenerationChunk +from langgraph.config import get_stream_writer class A2uiPartialHandler(AsyncCallbackHandler): - """Track per-tool_call_id cumulative arguments; dispatch a2ui-partial - custom events when the cumulative string grows.""" + """Track per-tool_call_id cumulative arguments; write a2ui-partial + custom events to the LangGraph stream when the cumulative string + grows.""" def __init__(self, tool_name: str = "render_a2ui_surface") -> None: super().__init__() @@ -63,16 +74,16 @@ async def on_llm_new_token( # Path 1: Chat Completions API — classic tool_call_chunks list. for tc in getattr(message, "tool_call_chunks", None) or []: - await self._handle_classic_chunk(tc) + self._handle_classic_chunk(tc) # Path 2: Responses API — content blocks with type='function_call'. content = getattr(message, "content", None) if isinstance(content, list): for block in content: if isinstance(block, dict) and block.get("type") == "function_call": - await self._handle_responses_block(block) + self._handle_responses_block(block) - async def _handle_classic_chunk(self, tc: dict) -> None: + def _handle_classic_chunk(self, tc: dict) -> None: """Chat Completions delta: {id, name?, args?, index}.""" name = tc.get("name") or "" call_id = tc.get("id") @@ -84,9 +95,9 @@ async def _handle_classic_chunk(self, tc: dict) -> None: # on every chunk in this shape). if name and name != self._tool_name: return - await self._dispatch_delta(call_id, delta) + self._dispatch_delta(call_id, delta) - async def _handle_responses_block(self, block: dict) -> None: + def _handle_responses_block(self, block: dict) -> None: """Responses API delta: {type, name?, call_id?, arguments, index}. The first block for a call carries `name` and `call_id`. Subsequent @@ -101,7 +112,7 @@ async def _handle_responses_block(self, block: dict) -> None: if call_id: # First block for this call. Validate target tool name. if name and name != self._tool_name: - # Not our target — still remember the mapping is "ignore". + # Not our target — drop any prior mapping for this index. if isinstance(index, int): self._index_to_call_id.pop(index, None) return @@ -117,15 +128,26 @@ async def _handle_responses_block(self, block: dict) -> None: # was for a non-target tool. Skip. return - await self._dispatch_delta(call_id, delta) + self._dispatch_delta(call_id, delta) - async def _dispatch_delta(self, call_id: str, delta: str) -> None: + def _dispatch_delta(self, call_id: str, delta: str) -> None: existing = self._buffers.get(call_id, "") updated = existing + delta if updated == existing: return # No growth — suppress duplicate dispatch. self._buffers[call_id] = updated - await adispatch_custom_event( - "a2ui-partial", - {"tool_call_id": call_id, "args_so_far": updated}, - ) + # `get_stream_writer()` is contextvar-scoped to the currently + # executing LangGraph node. Even though this handler runs deep + # inside the LLM's callback chain, the contextvar is inherited + # so the writer reaches the parent node's `custom` stream. + try: + writer = get_stream_writer() + except RuntimeError: + # No stream writer in this context — handler is being invoked + # outside a LangGraph stream run (e.g. in unit tests). Silently + # skip; tests can mock get_stream_writer to assert behavior. + return + writer({ + "name": "a2ui-partial", + "data": {"tool_call_id": call_id, "args_so_far": updated}, + }) diff --git a/examples/chat/python/tests/test_a2ui_partial_handler.py b/examples/chat/python/tests/test_a2ui_partial_handler.py index 64d3858b5..3b7e1f729 100644 --- a/examples/chat/python/tests/test_a2ui_partial_handler.py +++ b/examples/chat/python/tests/test_a2ui_partial_handler.py @@ -1,5 +1,6 @@ -"""Tests for A2uiPartialHandler — drives canned on_llm_new_token events.""" -from unittest.mock import patch, AsyncMock +"""Tests for A2uiPartialHandler — drives canned on_llm_new_token events +and asserts on the data written to the LangGraph stream writer.""" +from unittest.mock import patch, MagicMock from uuid import uuid4 import pytest @@ -18,21 +19,34 @@ def _make_chunk(tool_call_chunks: list[dict]) -> ChatGenerationChunk: ) +def _make_responses_chunk(blocks: list[dict]) -> ChatGenerationChunk: + """Wrap Responses-API content blocks (gpt-5 family) in a ChatGenerationChunk.""" + return ChatGenerationChunk( + text="", + message=AIMessageChunk(content=blocks), + ) + + class TestA2uiPartialHandler: @pytest.mark.asyncio async def test_dispatches_event_when_chunk_grows_args(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): chunk = _make_chunk([ {"id": "tc-1", "name": "render_a2ui_surface", "args": "{\"envelopes\":[", "index": 0}, ]) await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4()) - mock.assert_awaited_once_with("a2ui-partial", {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["}) + writer.assert_called_once_with({ + "name": "a2ui-partial", + "data": {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["}, + }) @pytest.mark.asyncio async def test_concatenates_args_across_chunks_same_tool_call_id(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): await handler.on_llm_new_token( "", chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), @@ -43,29 +57,28 @@ async def test_concatenates_args_across_chunks_same_tool_call_id(self): chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]), run_id=uuid4(), ) - # Second dispatch carries the cumulative string. - assert mock.await_count == 2 - args = [call.args for call in mock.await_args_list] - assert args[0][1]["args_so_far"] == "{" - assert args[1][1]["args_so_far"] == "{\"x\":1}" + assert writer.call_count == 2 + args = [c.args[0] for c in writer.call_args_list] + assert args[0]["data"]["args_so_far"] == "{" + assert args[1]["data"]["args_so_far"] == "{\"x\":1}" @pytest.mark.asyncio async def test_ignores_chunks_for_unrelated_tools(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): await handler.on_llm_new_token( "", chunk=_make_chunk([{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]), run_id=uuid4(), ) - mock.assert_not_awaited() + writer.assert_not_called() @pytest.mark.asyncio async def test_no_dispatch_when_args_did_not_grow(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - # First chunk grows the buffer; second chunk has empty args delta - # (a no-op chunk from the model) and must not re-dispatch. + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): await handler.on_llm_new_token( "", chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), @@ -76,12 +89,13 @@ async def test_no_dispatch_when_args_did_not_grow(self): chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]), run_id=uuid4(), ) - mock.assert_awaited_once() + writer.assert_called_once() @pytest.mark.asyncio async def test_per_tool_call_id_state_isolation(self): handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): await handler.on_llm_new_token( "", chunk=_make_chunk([ @@ -90,70 +104,64 @@ async def test_per_tool_call_id_state_isolation(self): ]), run_id=uuid4(), ) - assert mock.await_count == 2 - ids = {call.args[1]["tool_call_id"] for call in mock.await_args_list} + assert writer.call_count == 2 + ids = {c.args[0]["data"]["tool_call_id"] for c in writer.call_args_list} assert ids == {"tc-A", "tc-B"} @pytest.mark.asyncio async def test_ignores_token_event_without_chunk_message(self): - """Some emitters of on_llm_new_token may pass chunk=None (legacy LLM - path). Handler must silently skip — no crash, no dispatch.""" handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): await handler.on_llm_new_token("some token", chunk=None, run_id=uuid4()) - mock.assert_not_awaited() + writer.assert_not_called() + + @pytest.mark.asyncio + async def test_silently_skips_when_no_stream_writer_context(self): + """When invoked outside a LangGraph stream context (e.g. raw script), + get_stream_writer() raises RuntimeError. Handler should swallow.""" + handler = A2uiPartialHandler(tool_name="render_a2ui_surface") + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", side_effect=RuntimeError("not in stream")): + # Must not raise. + await handler.on_llm_new_token( + "", + chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]), + run_id=uuid4(), + ) @pytest.mark.asyncio async def test_responses_api_function_call_content_blocks(self): """gpt-5 / Responses API streams tool-call deltas as content blocks - of type='function_call' rather than tool_call_chunks. The first - block for each call carries name + call_id; subsequent blocks for - the same index carry only the args delta.""" + of type='function_call'. The first block carries name + call_id; + subsequent blocks for the same index carry only args delta.""" handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - first = ChatGenerationChunk( - text="", - message=AIMessageChunk(content=[ + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): + await handler.on_llm_new_token("", chunk=_make_responses_chunk([ {"type": "function_call", "name": "render_a2ui_surface", - "call_id": "call_ABC", "id": "fc_1", - "arguments": "", "index": 1}, - ]), - ) - next1 = ChatGenerationChunk( - text="", - message=AIMessageChunk(content=[ + "call_id": "call_ABC", "id": "fc_1", "arguments": "", "index": 1}, + ]), run_id=uuid4()) + await handler.on_llm_new_token("", chunk=_make_responses_chunk([ {"type": "function_call", "arguments": "{\"en", "index": 1}, - ]), - ) - next2 = ChatGenerationChunk( - text="", - message=AIMessageChunk(content=[ + ]), run_id=uuid4()) + await handler.on_llm_new_token("", chunk=_make_responses_chunk([ {"type": "function_call", "arguments": "velopes", "index": 1}, - ]), - ) - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - await handler.on_llm_new_token("", chunk=first, run_id=uuid4()) - await handler.on_llm_new_token("", chunk=next1, run_id=uuid4()) - await handler.on_llm_new_token("", chunk=next2, run_id=uuid4()) - # First block has empty args — no dispatch. Two subsequent grow the buffer. - assert mock.await_count == 2 - call1 = mock.await_args_list[0].args[1] - call2 = mock.await_args_list[1].args[1] - assert call1["tool_call_id"] == "call_ABC" - assert call1["args_so_far"] == "{\"en" - assert call2["tool_call_id"] == "call_ABC" - assert call2["args_so_far"] == "{\"envelopes" + ]), run_id=uuid4()) + # First block has empty args — no dispatch. Two subsequent grow. + assert writer.call_count == 2 + first = writer.call_args_list[0].args[0] + second = writer.call_args_list[1].args[0] + assert first["data"]["tool_call_id"] == "call_ABC" + assert first["data"]["args_so_far"] == "{\"en" + assert second["data"]["args_so_far"] == "{\"envelopes" @pytest.mark.asyncio async def test_responses_api_ignores_non_target_tools(self): - """Responses-API content blocks for a different tool are ignored.""" handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - chunk = ChatGenerationChunk( - text="", - message=AIMessageChunk(content=[ + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): + await handler.on_llm_new_token("", chunk=_make_responses_chunk([ {"type": "function_call", "name": "search_documents", "call_id": "call_X", "arguments": "{}", "index": 0}, - ]), - ) - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: - await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4()) - mock.assert_not_awaited() + ]), run_id=uuid4()) + writer.assert_not_called() diff --git a/examples/chat/python/tests/test_streaming_smoke.py b/examples/chat/python/tests/test_streaming_smoke.py index 04e3b9795..960154bb7 100644 --- a/examples/chat/python/tests/test_streaming_smoke.py +++ b/examples/chat/python/tests/test_streaming_smoke.py @@ -1,8 +1,8 @@ """Integration smoke: when fed a canned chat-model token stream, the -A2uiPartialHandler dispatches at least 3 a2ui-partial custom events -carrying the cumulative tool_call.args string.""" +A2uiPartialHandler writes at least 3 a2ui-partial entries to the +LangGraph stream writer carrying the cumulative tool_call.args string.""" import json -from unittest.mock import patch, AsyncMock +from unittest.mock import patch, MagicMock from uuid import uuid4 import pytest @@ -36,16 +36,18 @@ def _make_canned_stream() -> list[ChatGenerationChunk]: @pytest.mark.asyncio -async def test_handler_dispatches_per_chunk(): - """At least 3 a2ui-partial events fire as the canned stream advances.""" +async def test_handler_writes_per_chunk(): + """At least 3 a2ui-partial entries written to the stream as the canned + stream advances.""" handler = A2uiPartialHandler(tool_name="render_a2ui_surface") - with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock: + writer = MagicMock() + with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer): for chunk in _make_canned_stream(): await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4()) - assert mock.await_count >= 3 - # Last cumulative string is the full envelope JSON. - last = mock.await_args_list[-1].args[1] - assert last["tool_call_id"] == "tc-1" - body = json.loads(last["args_so_far"]) + assert writer.call_count >= 3 + last = writer.call_args_list[-1].args[0] + assert last["name"] == "a2ui-partial" + assert last["data"]["tool_call_id"] == "tc-1" + body = json.loads(last["data"]["args_so_far"]) assert "envelopes" in body assert len(body["envelopes"]) == 3