33 changes: 25 additions & 8 deletions examples/chat/python/src/streaming/a2ui_partial_handler.py
@@ -1,37 +1,54 @@
 # SPDX-License-Identifier: MIT
 """Streaming callback handler that sidebands a parent LLM's growing
 tool_call.arguments as A2UI-partial custom events. Listens to LangChain's
-on_chat_model_stream events; per tool_call_id, concatenates argument
+on_llm_new_token events (the canonical streaming-token callback for
+both chat models and legacy LLMs); per tool_call_id, concatenates argument
 deltas and dispatches each cumulative state via adispatch_custom_event.
 The frontend bridge (libs/chat partial-args-bridge) consumes these.
 """
 from __future__ import annotations
 
 from typing import Any
+from uuid import UUID
 
 from langchain_core.callbacks import AsyncCallbackHandler, adispatch_custom_event
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 
 
 class A2uiPartialHandler(AsyncCallbackHandler):
     """Track per-tool_call_id cumulative arguments; dispatch a2ui-partial
-    custom events when the cumulative string grows."""
+    custom events when the cumulative string grows.
+
+    Hooks into `on_llm_new_token` — the canonical streaming-token callback
+    fired by ChatOpenAI when `streaming=True` is enabled. For each chunk
+    we inspect the embedded `chunk.message.tool_call_chunks` list (only
+    populated when the LLM is mid-stream of a tool_call) and forward any
+    delta belonging to our target tool name to the frontend.
+    """
 
     def __init__(self, tool_name: str = "render_a2ui_surface") -> None:
         super().__init__()
         self._tool_name = tool_name
         # tool_call_id -> cumulative args string
         self._buffers: dict[str, str] = {}
 
-    async def on_chat_model_stream(
+    async def on_llm_new_token(
         self,
-        chunk: Any,
+        token: str,
         *,
-        run_id: str | None = None,
+        chunk: ChatGenerationChunk | GenerationChunk | None = None,
+        run_id: UUID | None = None,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> None:
-        # `chunk` is an AIMessageChunk. Each chunk may carry multiple
-        # tool_call_chunks (e.g. interleaved across concurrent tool_calls).
-        tool_call_chunks = getattr(chunk, "tool_call_chunks", None) or []
+        # We only care about chat-model chunks that carry tool_call_chunks.
+        if chunk is None:
+            return
+        message = getattr(chunk, "message", None)
+        if message is None:
+            return
+        tool_call_chunks = getattr(message, "tool_call_chunks", None) or []
         for tc in tool_call_chunks:
             name = tc.get("name") or ""
             call_id = tc.get("id")
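Note for reviewers trying this locally: the handler only fires if it is registered as a callback on a model that streams. A minimal wiring sketch, assuming `langchain-openai` is installed; the model name is an illustrative placeholder, and binding the actual `render_a2ui_surface` tool happens elsewhere in the example:

```python
# Minimal wiring sketch (not part of this diff). Assumes langchain-openai;
# the model name is an illustrative placeholder.
from langchain_openai import ChatOpenAI

from src.streaming.a2ui_partial_handler import A2uiPartialHandler

# streaming=True makes LangChain fire on_llm_new_token once per streamed
# chunk, which is the callback A2uiPartialHandler now overrides.
llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)
llm_with_handler = llm.with_config(
    callbacks=[A2uiPartialHandler(tool_name="render_a2ui_surface")]
)
```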
68 changes: 47 additions & 21 deletions examples/chat/python/tests/test_a2ui_partial_handler.py
@@ -1,34 +1,47 @@
-"""Tests for A2uiPartialHandler — drives canned on_chat_model_stream events."""
+"""Tests for A2uiPartialHandler — drives canned on_llm_new_token events."""
 from unittest.mock import patch, AsyncMock
+from uuid import uuid4
 
 import pytest
 from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 
 from src.streaming.a2ui_partial_handler import A2uiPartialHandler
 
 
+def _make_chunk(tool_call_chunks: list[dict]) -> ChatGenerationChunk:
+    """Wrap an AIMessageChunk in a ChatGenerationChunk the way the real
+    LangChain streaming callback path does."""
+    return ChatGenerationChunk(
+        text="",
+        message=AIMessageChunk(content="", tool_call_chunks=tool_call_chunks),
+    )
+
+
 class TestA2uiPartialHandler:
     @pytest.mark.asyncio
     async def test_dispatches_event_when_chunk_grows_args(self):
         handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
         with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
-            chunk = AIMessageChunk(content="", tool_call_chunks=[
+            chunk = _make_chunk([
                 {"id": "tc-1", "name": "render_a2ui_surface", "args": "{\"envelopes\":[", "index": 0},
             ])
-            await handler.on_chat_model_stream(chunk, run_id="r1")
+            await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
             mock.assert_awaited_once_with("a2ui-partial", {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["})
 
     @pytest.mark.asyncio
     async def test_concatenates_args_across_chunks_same_tool_call_id(self):
         handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
         with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
-            await handler.on_chat_model_stream(
-                AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
-                run_id="r1",
+            await handler.on_llm_new_token(
+                "",
+                chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
+                run_id=uuid4(),
             )
-            await handler.on_chat_model_stream(
-                AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]),
-                run_id="r1",
+            await handler.on_llm_new_token(
+                "",
+                chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "\"x\":1}", "index": 0}]),
+                run_id=uuid4(),
             )
             # Second dispatch carries the cumulative string.
             assert mock.await_count == 2
@@ -40,9 +53,10 @@ async def test_concatenates_args_across_chunks_same_tool_call_id(self):
     async def test_ignores_chunks_for_unrelated_tools(self):
         handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
         with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
-            await handler.on_chat_model_stream(
-                AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]),
-                run_id="r1",
+            await handler.on_llm_new_token(
+                "",
+                chunk=_make_chunk([{"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}]),
+                run_id=uuid4(),
             )
             mock.assert_not_awaited()
 
@@ -52,27 +66,39 @@ async def test_no_dispatch_when_args_did_not_grow(self):
         with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
             # First chunk grows the buffer; second chunk has empty args delta
             # (a no-op chunk from the model) and must not re-dispatch.
-            await handler.on_chat_model_stream(
-                AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
-                run_id="r1",
+            await handler.on_llm_new_token(
+                "",
+                chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "{", "index": 0}]),
+                run_id=uuid4(),
             )
-            await handler.on_chat_model_stream(
-                AIMessageChunk(content="", tool_call_chunks=[{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]),
-                run_id="r1",
+            await handler.on_llm_new_token(
+                "",
+                chunk=_make_chunk([{"id": "tc-1", "name": "render_a2ui_surface", "args": "", "index": 0}]),
+                run_id=uuid4(),
             )
             mock.assert_awaited_once()
 
     @pytest.mark.asyncio
     async def test_per_tool_call_id_state_isolation(self):
         handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
         with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
-            await handler.on_chat_model_stream(
-                AIMessageChunk(content="", tool_call_chunks=[
+            await handler.on_llm_new_token(
+                "",
+                chunk=_make_chunk([
                     {"id": "tc-A", "name": "render_a2ui_surface", "args": "{", "index": 0},
                     {"id": "tc-B", "name": "render_a2ui_surface", "args": "[", "index": 1},
                 ]),
-                run_id="r1",
+                run_id=uuid4(),
             )
             assert mock.await_count == 2
             ids = {call.args[1]["tool_call_id"] for call in mock.await_args_list}
             assert ids == {"tc-A", "tc-B"}
+
+    @pytest.mark.asyncio
+    async def test_ignores_token_event_without_chunk_message(self):
+        """Some emitters of on_llm_new_token may pass chunk=None (legacy LLM
+        path). Handler must silently skip — no crash, no dispatch."""
+        handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
+        with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
+            await handler.on_llm_new_token("some token", chunk=None, run_id=uuid4())
+            mock.assert_not_awaited()
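Downstream, the dispatched `a2ui-partial` events surface through LangChain's `astream_events` API (assuming a recent langchain-core, where `adispatch_custom_event` is propagated). A sketch of a consumer; the `chain`, inputs, and the `print` stand-in for the bridge transport are illustrative placeholders:

```python
# Consumer sketch (illustrative, not part of this diff).
from typing import Any

from langchain_core.runnables import Runnable


async def pump_a2ui_partials(chain: Runnable, inputs: dict[str, Any]) -> None:
    """Relay a2ui-partial custom events from one run to the frontend bridge."""
    async for event in chain.astream_events(inputs, version="v2"):
        # adispatch_custom_event surfaces here as event type "on_custom_event",
        # keyed by the name the handler dispatched with.
        if event["event"] == "on_custom_event" and event["name"] == "a2ui-partial":
            payload = event["data"]  # {"tool_call_id": ..., "args_so_far": ...}
            print(payload)  # stand-in for the libs/chat partial-args-bridge
```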
51 changes: 25 additions & 26 deletions examples/chat/python/tests/test_streaming_smoke.py
@@ -1,38 +1,37 @@
-"""Integration smoke: the generate node, when invoked with a canned chat-model
-stream, dispatches a2ui-partial events and the final state has the
-single-bubble shape from PR #255."""
+"""Integration smoke: when fed a canned chat-model token stream, the
+A2uiPartialHandler dispatches at least 3 a2ui-partial custom events
+carrying the cumulative tool_call.args string."""
 import json
 from unittest.mock import patch, AsyncMock
+from uuid import uuid4
 
 import pytest
 from langchain_core.messages import AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 
 from src.streaming.a2ui_partial_handler import A2uiPartialHandler
 
 
-def _make_canned_stream() -> list[AIMessageChunk]:
-    """Five chunks of growing args for one tool_call to render_a2ui_surface."""
+def _make_canned_stream() -> list[ChatGenerationChunk]:
+    """Five chunks of growing args for one tool_call to render_a2ui_surface,
+    wrapped in ChatGenerationChunk as LangChain emits them from the
+    on_llm_new_token callback."""
+    deltas = [
+        '{"envelopes":[',
+        '{"surfaceUpdate":{"surfaceId":"s","components":[{"id":"root","type":"text","props":{}}]}},',
+        '{"beginRendering":{"surfaceId":"s","root":"root"}},',
+        '{"dataModelUpdate":{"surfaceId":"s","contents":[{"key":"text","valueString":"hi"}]}}',
+        "]}",
+    ]
     return [
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"envelopes":[',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"surfaceUpdate":{"surfaceId":"s","components":[{"id":"root","type":"text","props":{}}]}},',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"beginRendering":{"surfaceId":"s","root":"root"}},',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": '{"dataModelUpdate":{"surfaceId":"s","contents":[{"key":"text","valueString":"hi"}]}}',
-        }]),
-        AIMessageChunk(content="", tool_call_chunks=[{
-            "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
-            "args": "]}",
-        }]),
+        ChatGenerationChunk(
+            text="",
+            message=AIMessageChunk(content="", tool_call_chunks=[{
+                "id": "tc-1", "name": "render_a2ui_surface", "index": 0,
+                "args": delta,
+            }]),
+        )
+        for delta in deltas
     ]
 
 
@@ -42,7 +41,7 @@ async def test_handler_dispatches_per_chunk():
     handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
    with patch("src.streaming.a2ui_partial_handler.adispatch_custom_event", new=AsyncMock()) as mock:
         for chunk in _make_canned_stream():
-            await handler.on_chat_model_stream(chunk, run_id="r1")
+            await handler.on_llm_new_token("", chunk=chunk, run_id=uuid4())
         assert mock.await_count >= 3
         # Last cumulative string is the full envelope JSON.
         last = mock.await_args_list[-1].args[1]
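The hunk is truncated here, so the file's remaining assertions on `last` are not shown. The five deltas above do concatenate into one parseable JSON document, so a check along these lines would hold; this is a hypothetical sketch, not the file's actual code:

```python
# Hypothetical continuation sketch: the diff is truncated above, so the
# real assertions may differ. The concatenated deltas form valid JSON.
assert last["tool_call_id"] == "tc-1"
parsed = json.loads(last["args_so_far"])
assert [next(iter(env)) for env in parsed["envelopes"]] == [
    "surfaceUpdate", "beginRendering", "dataModelUpdate",
]
```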