Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 95 additions & 31 deletions examples/chat/python/src/streaming/a2ui_partial_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
# SPDX-License-Identifier: MIT
"""Streaming callback handler that sidebands a parent LLM's growing
tool_call.arguments as A2UI-partial custom events. Listens to LangChain's
on_llm_new_token events (the canonical streaming-token callback for
both chat models and legacy LLMs); per tool_call_id, concatenates argument
deltas and dispatches each cumulative state via adispatch_custom_event.
The frontend bridge (libs/chat partial-args-bridge) consumes these.
tool_call.arguments as A2UI-partial custom events.

Hooks into LangChain's `on_llm_new_token` callback (the canonical
streaming-token callback fired by ChatOpenAI when streaming=True is
enabled). Two delta shapes are observed in production depending on
which OpenAI API the model uses:

- Chat Completions API: `chunk.message.tool_call_chunks` carries the
classic `[{name, id, args, index}, ...]` deltas.
- Responses API (gpt-5 family): `chunk.message.content` carries a
list of content blocks; tool-call deltas appear as
`{type: 'function_call', name?: str, call_id?: str, arguments: str,
index: int}`. The first block for each call carries `name` and
`call_id`; subsequent blocks for the same `index` carry only the
`arguments` delta.

The handler covers both shapes. For each delta belonging to our target
tool name, it concatenates arguments per `tool_call_id` and dispatches
an `a2ui-partial` custom event when the cumulative string grows. The
frontend bridge (libs/chat partial-args-bridge) consumes these.
"""
from __future__ import annotations

Expand All @@ -17,20 +32,18 @@

class A2uiPartialHandler(AsyncCallbackHandler):
"""Track per-tool_call_id cumulative arguments; dispatch a2ui-partial
custom events when the cumulative string grows.

Hooks into `on_llm_new_token` — the canonical streaming-token callback
fired by ChatOpenAI when `streaming=True` is enabled. For each chunk
we inspect the embedded `chunk.message.tool_call_chunks` list (only
populated when the LLM is mid-stream of a tool_call) and forward any
delta belonging to our target tool name to the frontend.
"""
custom events when the cumulative string grows."""

def __init__(self, tool_name: str = "render_a2ui_surface") -> None:
    """Create a handler that sidebands streaming args for *tool_name*."""
    super().__init__()
    self._tool_name = tool_name
    # Responses-API `index` -> tool_call_id. Only the first content block
    # of a call carries name + call_id; later blocks for the same index
    # carry just the args delta, so this mapping lets us attribute those
    # follow-up deltas to the right call.
    self._index_to_call_id: dict[int, str] = {}
    # tool_call_id -> arguments concatenated so far.
    self._buffers: dict[str, str] = {}

async def on_llm_new_token(
self,
Expand All @@ -42,26 +55,77 @@ async def on_llm_new_token(
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
# We only care about chat-model chunks that carry tool_call_chunks.
if chunk is None:
return
message = getattr(chunk, "message", None)
if message is None:
return
tool_call_chunks = getattr(message, "tool_call_chunks", None) or []
for tc in tool_call_chunks:
name = tc.get("name") or ""
call_id = tc.get("id")
delta = tc.get("args") or ""
if name != self._tool_name or not call_id:
continue
existing = self._buffers.get(call_id, "")
updated = existing + delta
if updated == existing:
# No growth — don't re-dispatch the same payload.
continue
self._buffers[call_id] = updated
await adispatch_custom_event(
"a2ui-partial",
{"tool_call_id": call_id, "args_so_far": updated},
)

# Path 1: Chat Completions API — classic tool_call_chunks list.
for tc in getattr(message, "tool_call_chunks", None) or []:
await self._handle_classic_chunk(tc)

# Path 2: Responses API — content blocks with type='function_call'.
content = getattr(message, "content", None)
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "function_call":
await self._handle_responses_block(block)

async def _handle_classic_chunk(self, tc: dict) -> None:
    """Handle one Chat Completions tool-call delta: {id?, name?, args?, index?}.

    OpenAI only guarantees `index` on every streamed delta: the first
    chunk of a call carries `name` and `id`, while follow-up chunks may
    omit either. The previous implementation had two gaps:
    - follow-up chunks that carried `id` but no `name` were dispatched
      even when the call belonged to a *different* tool (the name filter
      only fired on the first chunk), and
    - follow-up chunks that omitted `id` were dropped entirely.
    We close both gaps with state `__init__` already defines: accepted
    calls are registered in `_buffers` (so id-bearing follow-ups can be
    vetted) and in `_index_to_call_id` (so id-less follow-ups can be
    attributed by index, mirroring the Responses-API path).
    """
    name = tc.get("name") or ""
    call_id = tc.get("id")
    delta = tc.get("args") or ""
    index = tc.get("index")

    if call_id:
        if name:
            if name != self._tool_name:
                # First chunk of some other tool — reject it and forget
                # any index mapping so its follow-up deltas are ignored.
                if isinstance(index, int):
                    self._index_to_call_id.pop(index, None)
                return
            # First chunk of our target tool: register the call so
            # later deltas (with or without an id) can be attributed.
            self._buffers.setdefault(call_id, "")
            if isinstance(index, int):
                self._index_to_call_id[index] = call_id
        elif call_id not in self._buffers:
            # Id-bearing follow-up for a call we never accepted
            # (non-target tool, or its first chunk was missed) — skip.
            return
    else:
        # No id on this chunk — attribute it via the index mapping.
        if not isinstance(index, int):
            return
        call_id = self._index_to_call_id.get(index)
        if not call_id:
            # Never saw an accepted first chunk for this index.
            return

    await self._dispatch_delta(call_id, delta)

async def _handle_responses_block(self, block: dict) -> None:
    """Handle one Responses-API delta: {type, name?, call_id?, arguments, index}.

    The opening block of a call carries `name` and `call_id`; follow-up
    blocks for the same `index` carry only the `arguments` delta, so an
    index -> call_id map is used to attribute them.
    """
    idx = block.get("index")
    tool = block.get("name") or ""
    cid = block.get("call_id")
    args_delta = block.get("arguments") or ""

    if not cid:
        # Follow-up block — resolve the call through the remembered index.
        if not isinstance(idx, int):
            return
        cid = self._index_to_call_id.get(idx)
        if not cid:
            # No opening block was seen for this index, or it belonged
            # to a tool we don't track — nothing to attribute.
            return
    elif tool and tool != self._tool_name:
        # Opening block for a different tool: drop any stale mapping so
        # its follow-up deltas are skipped too.
        if isinstance(idx, int):
            self._index_to_call_id.pop(idx, None)
        return
    else:
        # Opening block for our target tool — remember the attribution.
        if isinstance(idx, int):
            self._index_to_call_id[idx] = cid

    await self._dispatch_delta(cid, args_delta)

async def _dispatch_delta(self, call_id: str, delta: str) -> None:
    """Append *delta* to the call's buffer and emit an event when it grows."""
    if not delta:
        # An empty delta cannot grow the buffer — suppress the duplicate
        # dispatch (equivalent to the old `updated == existing` check).
        return
    combined = self._buffers.get(call_id, "") + delta
    self._buffers[call_id] = combined
    payload = {"tool_call_id": call_id, "args_so_far": combined}
    await adispatch_custom_event("a2ui-partial", payload)
55 changes: 55 additions & 0 deletions examples/chat/python/tests/test_a2ui_partial_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,58 @@ async def test_ignores_token_event_without_chunk_message(self):
with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
await handler.on_llm_new_token("some token", chunk=None, run_id=uuid4())
mock.assert_not_awaited()

@pytest.mark.asyncio
async def test_responses_api_function_call_content_blocks(self):
    """The Responses API (gpt-5 family) streams tool-call deltas as
    content blocks of type='function_call' rather than tool_call_chunks.
    The opening block of each call carries name + call_id; later blocks
    for the same index carry only the args delta."""
    handler = A2uiPartialHandler(tool_name="render_a2ui_surface")

    def fc_chunk(block):
        # Wrap a single function_call content block in a streaming chunk.
        return ChatGenerationChunk(text="", message=AIMessageChunk(content=[block]))

    opening = fc_chunk({
        "type": "function_call",
        "name": "render_a2ui_surface",
        "call_id": "call_ABC",
        "id": "fc_1",
        "arguments": "",
        "index": 1,
    })
    delta_a = fc_chunk({"type": "function_call", "arguments": "{\"en", "index": 1})
    delta_b = fc_chunk({"type": "function_call", "arguments": "velopes", "index": 1})

    with patch(
        "src.streaming.a2ui_partial_handler.adispatch_custom_event",
        new=AsyncMock(),
    ) as dispatch:
        await handler.on_llm_new_token("", chunk=opening, run_id=uuid4())
        await handler.on_llm_new_token("", chunk=delta_a, run_id=uuid4())
        await handler.on_llm_new_token("", chunk=delta_b, run_id=uuid4())
        # The opening block carries empty args, so it dispatches nothing;
        # each of the two follow-up deltas grows the buffer once.
        assert dispatch.await_count == 2
        first_payload = dispatch.await_args_list[0].args[1]
        second_payload = dispatch.await_args_list[1].args[1]
        assert first_payload["tool_call_id"] == "call_ABC"
        assert first_payload["args_so_far"] == "{\"en"
        assert second_payload["tool_call_id"] == "call_ABC"
        assert second_payload["args_so_far"] == "{\"envelopes"

@pytest.mark.asyncio
async def test_responses_api_ignores_non_target_tools(self):
    """Responses-API content blocks describing another tool's call never dispatch."""
    handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
    other_tool_block = {
        "type": "function_call",
        "name": "search_documents",
        "call_id": "call_X",
        "arguments": "{}",
        "index": 0,
    }
    chunk = ChatGenerationChunk(
        text="",
        message=AIMessageChunk(content=[other_tool_block]),
    )
    with patch(
        "src.streaming.a2ui_partial_handler.adispatch_custom_event",
        new=AsyncMock(),
    ) as dispatch:
        await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
        dispatch.assert_not_awaited()
Loading