58 changes: 40 additions & 18 deletions examples/chat/python/src/streaming/a2ui_partial_handler.py
@@ -17,22 +17,33 @@
`arguments` delta.

The handler covers both shapes. For each delta belonging to our target
tool name, it concatenates arguments per `tool_call_id` and dispatches
an `a2ui-partial` custom event when the cumulative string grows. The
frontend bridge (libs/chat partial-args-bridge) consumes these.
tool name, it concatenates arguments per `tool_call_id` and writes an
`a2ui-partial` event into LangGraph's custom stream via the writer
returned by `langgraph.config.get_stream_writer()`. The writer's
payload reaches the frontend under `stream_mode='custom'` as a
`{type: 'custom', data: ...}` SSE event; the partial-args bridge
consumes it.

LangGraph's `custom` stream mode is decoupled from langchain_core's
`adispatch_custom_event` — they're different layers. `adispatch_custom_event`
emits callback events visible only via `stream_mode='events'`; the
LangGraph writer is the canonical way to surface custom data on the
SDK's `custom` channel.
"""
from __future__ import annotations

from typing import Any
from uuid import UUID

from langchain_core.callbacks import AsyncCallbackHandler, adispatch_custom_event
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langgraph.config import get_stream_writer


class A2uiPartialHandler(AsyncCallbackHandler):
"""Track per-tool_call_id cumulative arguments; dispatch a2ui-partial
custom events when the cumulative string grows."""
"""Track per-tool_call_id cumulative arguments; write a2ui-partial
custom events to the LangGraph stream when the cumulative string
grows."""

def __init__(self, tool_name: str = "render_a2ui_surface") -> None:
super().__init__()
@@ -63,16 +74,16 @@ async def on_llm_new_token(

# Path 1: Chat Completions API — classic tool_call_chunks list.
for tc in getattr(message, "tool_call_chunks", None) or []:
await self._handle_classic_chunk(tc)
self._handle_classic_chunk(tc)

# Path 2: Responses API — content blocks with type='function_call'.
content = getattr(message, "content", None)
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "function_call":
await self._handle_responses_block(block)
self._handle_responses_block(block)

async def _handle_classic_chunk(self, tc: dict) -> None:
def _handle_classic_chunk(self, tc: dict) -> None:
"""Chat Completions delta: {id, name?, args?, index}."""
name = tc.get("name") or ""
call_id = tc.get("id")
@@ -84,9 +95,9 @@ async def _handle_classic_chunk(self, tc: dict) -> None:
# on every chunk in this shape).
if name and name != self._tool_name:
return
await self._dispatch_delta(call_id, delta)
self._dispatch_delta(call_id, delta)

async def _handle_responses_block(self, block: dict) -> None:
def _handle_responses_block(self, block: dict) -> None:
"""Responses API delta: {type, name?, call_id?, arguments, index}.

The first block for a call carries `name` and `call_id`. Subsequent
@@ -101,7 +112,7 @@ async def _handle_responses_block(self, block: dict) -> None:
if call_id:
# First block for this call. Validate target tool name.
if name and name != self._tool_name:
# Not our target — still remember the mapping is "ignore".
# Not our target — drop any prior mapping for this index.
if isinstance(index, int):
self._index_to_call_id.pop(index, None)
return
@@ -117,15 +128,26 @@ async def _handle_responses_block(self, block: dict) -> None:
# was for a non-target tool. Skip.
return

await self._dispatch_delta(call_id, delta)
self._dispatch_delta(call_id, delta)

async def _dispatch_delta(self, call_id: str, delta: str) -> None:
def _dispatch_delta(self, call_id: str, delta: str) -> None:
existing = self._buffers.get(call_id, "")
updated = existing + delta
if updated == existing:
return # No growth — suppress duplicate dispatch.
self._buffers[call_id] = updated
await adispatch_custom_event(
"a2ui-partial",
{"tool_call_id": call_id, "args_so_far": updated},
)
# `get_stream_writer()` is contextvar-scoped to the currently
# executing LangGraph node. Even though this handler runs deep
# inside the LLM's callback chain, the contextvar is inherited
# so the writer reaches the parent node's `custom` stream.
try:
writer = get_stream_writer()
except RuntimeError:
# No stream writer in this context — handler is being invoked
# outside a LangGraph stream run (e.g. in unit tests). Silently
# skip; tests can mock get_stream_writer to assert behavior.
return
writer({
"name": "a2ui-partial",
"data": {"tool_call_id": call_id, "args_so_far": updated},
})
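
A minimal wiring sketch, not part of this PR: the names build_graph and demo, and the ChatOpenAI model choice, are illustrative assumptions (the example app's real graph wiring may differ). It shows the handler attached as a per-call callback inside a LangGraph node, so get_stream_writer() inside the handler resolves to that node's writer, and the written payloads surfacing directly when the compiled graph is streamed with stream_mode="custom".

from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph

from src.streaming.a2ui_partial_handler import A2uiPartialHandler


def build_graph():
    # streaming=True makes ainvoke stream tokens internally, so
    # on_llm_new_token fires for every delta the handler cares about.
    # In the real app the model would also have the render_a2ui_surface
    # tool bound; that is omitted here to keep the sketch short.
    model = ChatOpenAI(model="gpt-4o", streaming=True)  # assumed model

    async def agent(state: MessagesState):
        handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
        # The handler is passed per call; its get_stream_writer() call
        # picks up this node's writer via the ambient contextvar.
        response = await model.ainvoke(
            state["messages"], config={"callbacks": [handler]}
        )
        return {"messages": [response]}

    graph = StateGraph(MessagesState)
    graph.add_node("agent", agent)
    graph.add_edge(START, "agent")
    graph.add_edge("agent", END)
    return graph.compile()


async def demo() -> None:
    graph = build_graph()
    # stream_mode="custom" yields exactly the dicts the handler wrote.
    async for payload in graph.astream(
        {"messages": [("user", "Render a dashboard")]}, stream_mode="custom"
    ):
        if payload.get("name") == "a2ui-partial":
            print(payload["data"]["args_so_far"])
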
136 changes: 72 additions & 64 deletions examples/chat/python/tests/test_a2ui_partial_handler.py
@@ -1,5 +1,6 @@
"""Tests for A2uiPartialHandler — drives canned on_llm_new_token events."""
from unittest.mock import patch, AsyncMock
"""Tests for A2uiPartialHandler — drives canned on_llm_new_token events
and asserts on the data written to the LangGraph stream writer."""
from unittest.mock import patch, MagicMock
from uuid import uuid4

import pytest
@@ -18,21 +19,34 @@ def _make_chunk(tool_call_chunks: list[dict]) -> ChatGenerationChunk:
)


def _make_responses_chunk(blocks: list[dict]) -> ChatGenerationChunk:
"""Wrap Responses-API content blocks (gpt-5 family) in a ChatGenerationChunk."""
return ChatGenerationChunk(
text="",
message=AIMessageChunk(content=blocks),
)


class TestA2uiPartialHandler:
@pytest.mark.asyncio
async def test_dispatches_event_when_chunk_grows_args(self):
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
chunk = _make_chunk([
{"id": "tc-1", "name": "render_a2ui_surface", "args": "{\"envelopes\":[", "index": 0},
])
await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
mock.assert_awaited_once_with("a2ui-partial", {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["})
writer.assert_called_once_with({
"name": "a2ui-partial",
"data": {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["},
})

@pytest.mark.asyncio
async def test_concatenates_args_across_chunks_same_tool_call_id(self):
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token(
"",
chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
@@ -43,29 +57,28 @@ async def test_concatenates_args_across_chunks_same_tool_call_id(self):
chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]),
run_id=uuid4(),
)
# Second dispatch carries the cumulative string.
assert mock.await_count == 2
args = [call.args for call in mock.await_args_list]
assert args[0][1]["args_so_far"] == "{"
assert args[1][1]["args_so_far"] == "{\"x\":1}"
assert writer.call_count == 2
args = [c.args[0] for c in writer.call_args_list]
assert args[0]["data"]["args_so_far"] == "{"
assert args[1]["data"]["args_so_far"] == "{\"x\":1}"

@pytest.mark.asyncio
async def test_ignores_chunks_for_unrelated_tools(self):
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token(
"",
chunk=_make_chunk([{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]),
run_id=uuid4(),
)
mock.assert_not_awaited()
writer.assert_not_called()

@pytest.mark.asyncio
async def test_no_dispatch_when_args_did_not_grow(self):
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
# First chunk grows the buffer; second chunk has empty args delta
# (a no-op chunk from the model) and must not re-dispatch.
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token(
"",
chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
@@ -76,12 +89,13 @@ async def test_no_dispatch_when_args_did_not_grow(self):
chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]),
run_id=uuid4(),
)
mock.assert_awaited_once()
writer.assert_called_once()

@pytest.mark.asyncio
async def test_per_tool_call_id_state_isolation(self):
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token(
"",
chunk=_make_chunk([
@@ -90,70 +104,64 @@ async def test_per_tool_call_id_state_isolation(self):
]),
run_id=uuid4(),
)
assert mock.await_count == 2
ids = {call.args[1]["tool_call_id"] for call in mock.await_args_list}
assert writer.call_count == 2
ids = {c.args[0]["data"]["tool_call_id"] for c in writer.call_args_list}
assert ids == {"tc-A", "tc-B"}

@pytest.mark.asyncio
async def test_ignores_token_event_without_chunk_message(self):
"""Some emitters of on_llm_new_token may pass chunk=None (legacy LLM
path). Handler must silently skip — no crash, no dispatch."""
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token("some token", chunk=None, run_id=uuid4())
mock.assert_not_awaited()
writer.assert_not_called()

@pytest.mark.asyncio
async def test_silently_skips_when_no_stream_writer_context(self):
"""When invoked outside a LangGraph stream context (e.g. raw script),
get_stream_writer() raises RuntimeError. Handler should swallow."""
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", side_effect=RuntimeError("not in stream")):
# Must not raise.
await handler.on_llm_new_token(
"",
chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
run_id=uuid4(),
)

@pytest.mark.asyncio
async def test_responses_api_function_call_content_blocks(self):
"""gpt-5 / Responses API streams tool-call deltas as content blocks
of type='function_call' rather than tool_call_chunks. The first
block for each call carries name + call_id; subsequent blocks for
the same index carry only the args delta."""
of type='function_call'. The first block carries name + call_id;
subsequent blocks for the same index carry only args delta."""
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
first = ChatGenerationChunk(
text="",
message=AIMessageChunk(content=[
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token("", chunk=_make_responses_chunk([
{"type": "function_call", "name": "render_a2ui_surface",
"call_id": "call_ABC", "id": "fc_1",
"arguments": "", "index": 1},
]),
)
next1 = ChatGenerationChunk(
text="",
message=AIMessageChunk(content=[
"call_id": "call_ABC", "id": "fc_1", "arguments": "", "index": 1},
]), run_id=uuid4())
await handler.on_llm_new_token("", chunk=_make_responses_chunk([
{"type": "function_call", "arguments": "{\"en", "index": 1},
]),
)
next2 = ChatGenerationChunk(
text="",
message=AIMessageChunk(content=[
]), run_id=uuid4())
await handler.on_llm_new_token("", chunk=_make_responses_chunk([
{"type": "function_call", "arguments": "velopes", "index": 1},
]),
)
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
await handler.on_llm_new_token("", chunk=first, run_id=uuid4())
await handler.on_llm_new_token("", chunk=next1, run_id=uuid4())
await handler.on_llm_new_token("", chunk=next2, run_id=uuid4())
# First block has empty args — no dispatch. Two subsequent grow the buffer.
assert mock.await_count == 2
call1 = mock.await_args_list[0].args[1]
call2 = mock.await_args_list[1].args[1]
assert call1["tool_call_id"] == "call_ABC"
assert call1["args_so_far"] == "{\"en"
assert call2["tool_call_id"] == "call_ABC"
assert call2["args_so_far"] == "{\"envelopes"
]), run_id=uuid4())
# First block has empty args — no dispatch. Two subsequent grow.
assert writer.call_count == 2
first = writer.call_args_list[0].args[0]
second = writer.call_args_list[1].args[0]
assert first["data"]["tool_call_id"] == "call_ABC"
assert first["data"]["args_so_far"] == "{\"en"
assert second["data"]["args_so_far"] == "{\"envelopes"

@pytest.mark.asyncio
async def test_responses_api_ignores_non_target_tools(self):
"""Responses-API content blocks for a different tool are ignored."""
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
chunk = ChatGenerationChunk(
text="",
message=AIMessageChunk(content=[
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
await handler.on_llm_new_token("", chunk=_make_responses_chunk([
{"type": "function_call", "name": "search_documents",
"call_id": "call_X", "arguments": "{}", "index": 0},
]),
)
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
mock.assert_not_awaited()
]), run_id=uuid4())
writer.assert_not_called()
24 changes: 13 additions & 11 deletions examples/chat/python/tests/test_streaming_smoke.py
@@ -1,8 +1,8 @@
"""Integration smoke: when fed a canned chat-model token stream, the
A2uiPartialHandler dispatches at least 3 a2ui-partial custom events
carrying the cumulative tool_call.args string."""
A2uiPartialHandler writes at least 3 a2ui-partial entries to the
LangGraph stream writer carrying the cumulative tool_call.args string."""
import json
from unittest.mock import patch, AsyncMock
from unittest.mock import patch, MagicMock
from uuid import uuid4

import pytest
@@ -36,16 +36,18 @@ def _make_canned_stream() -> list[ChatGenerationChunk]:


@pytest.mark.asyncio
async def test_handler_dispatches_per_chunk():
"""At least 3 a2ui-partial events fire as the canned stream advances."""
async def test_handler_writes_per_chunk():
"""At least 3 a2ui-partial entries written to the stream as the canned
stream advances."""
handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
writer = MagicMock()
with patch("src.streaming.a2ui_partial_handler.get_stream_writer", return_value=writer):
for chunk in _make_canned_stream():
await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
assert mock.await_count >= 3
# Last cumulative string is the full envelope JSON.
last = mock.await_args_list[-1].args[1]
assert last["tool_call_id"] == "tc-1"
body = json.loads(last["args_so_far"])
assert writer.call_count >= 3
last = writer.call_args_list[-1].args[0]
assert last["name"] == "a2ui-partial"
assert last["data"]["tool_call_id"] == "tc-1"
body = json.loads(last["data"]["args_so_far"])
assert "envelopes" in body
assert len(body["envelopes"]) == 3
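
For completeness, a small consumer-side sketch, again an assumption rather than part of this PR: when the graph runs behind the LangGraph API, the same writer payloads arrive on the custom SSE channel and can be read with the Python SDK. The URL and assistant id below are placeholders, and the production consumer is the TypeScript partial-args bridge in libs/chat, not this script.

from langgraph_sdk import get_client


async def consume_partials() -> None:
    client = get_client(url="http://localhost:2024")  # assumed dev server URL
    thread = await client.threads.create()
    async for chunk in client.runs.stream(
        thread["thread_id"],
        "chat",  # assumed assistant id
        input={"messages": [{"role": "user", "content": "Render a dashboard"}]},
        stream_mode="custom",
    ):
        # Custom-mode chunks carry the writer payload verbatim in chunk.data.
        if chunk.event == "custom" and chunk.data.get("name") == "a2ui-partial":
            print(chunk.data["data"]["args_so_far"])
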