Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/chat/python/src/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from langchain_core.tools import tool
from langgraph_sdk import get_client

from src.streaming.a2ui_partial_handler import A2uiPartialHandler
from src.streaming.envelope_tool import render_a2ui_surface
from src.streaming.envelope_normalizer import normalize_envelope_args
from src.schemas.a2ui_v1 import A2UI_V1_SCHEMA_PROMPT
Expand Down Expand Up @@ -410,7 +411,14 @@ async def generate(state: State, config: RunnableConfig) -> dict:
"entries. This lets the client mount the surface as early as possible."
)
messages = [SystemMessage(content=system)] + state["messages"]
response = await llm.ainvoke(messages)
# When in a2ui mode, attach the partial-envelope sideband handler so
# the parent LLM's tool_call_chunks for render_a2ui_surface are
# dispatched as a2ui-partial custom events (consumed by the frontend
# partial-args bridge).
callbacks = []
if gen_ui_mode == "a2ui":
callbacks.append(A2uiPartialHandler(tool_name="render_a2ui_surface"))
response = await llm.ainvoke(messages, config={"callbacks": callbacks})
return {"messages": [response]}


Expand Down
50 changes: 50 additions & 0 deletions examples/chat/python/src/streaming/a2ui_partial_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: MIT
"""Streaming callback handler that sidebands a parent LLM's growing
tool_call.arguments as A2UI-partial custom events. Listens to LangChain's
on_chat_model_stream events; per tool_call_id, concatenates argument
deltas and dispatches each cumulative state via adispatch_custom_event.
The frontend bridge (libs/chat partial-args-bridge) consumes these.
"""
from __future__ import annotations

from typing import Any

from langchain_core.callbacks import AsyncCallbackHandler, adispatch_custom_event


class A2uiPartialHandler(AsyncCallbackHandler):
    """Track per-tool_call cumulative arguments; dispatch ``a2ui-partial``
    custom events whenever the accumulated argument string grows.

    Buffers are keyed by tool_call_id. Streaming providers typically send
    ``name`` and ``id`` only on the *first* chunk of a tool call — later
    deltas carry just ``index`` and ``args`` — so an index -> tool_call_id
    map is kept to attribute those continuation chunks to the right buffer
    instead of silently dropping them.
    """

    def __init__(self, tool_name: str = "render_a2ui_surface") -> None:
        super().__init__()
        self._tool_name = tool_name
        # tool_call_id -> cumulative args string
        self._buffers: dict[str, str] = {}
        # tool_call index -> tool_call_id (resolves id-less continuation chunks)
        self._index_to_id: dict[int, str] = {}
        # indices announced with a *different* tool name; their deltas are skipped
        self._foreign_indices: set[int] = set()

    async def on_chat_model_stream(
        self,
        chunk: Any,
        *,
        run_id: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Accumulate args deltas per tool call and sideband each growth.

        `chunk` is an AIMessageChunk. Each chunk may carry multiple
        tool_call_chunks (e.g. interleaved across concurrent tool_calls).
        """
        tool_call_chunks = getattr(chunk, "tool_call_chunks", None) or []
        for tc in tool_call_chunks:
            name = tc.get("name")
            index = tc.get("index")
            call_id = tc.get("id")

            if name is not None:
                # Named chunk: decide once whether this index belongs to us.
                if name != self._tool_name:
                    if index is not None:
                        self._foreign_indices.add(index)
                    continue
                if call_id and index is not None:
                    self._index_to_id[index] = call_id
            else:
                # Continuation chunk (no name): skip other tools' deltas and
                # recover the id recorded for this index when it is missing.
                if index in self._foreign_indices:
                    continue
                if not call_id:
                    call_id = self._index_to_id.get(index)
            if not call_id:
                continue

            delta = tc.get("args") or ""
            if not delta:
                # No growth — don't re-dispatch the same payload.
                continue
            updated = self._buffers.get(call_id, "") + delta
            self._buffers[call_id] = updated
            await adispatch_custom_event(
                "a2ui-partial",
                {"tool_call_id": call_id, "args_so_far": updated},
            )
78 changes: 78 additions & 0 deletions examples/chat/python/tests/test_a2ui_partial_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Tests for A2uiPartialHandler — drives canned on_chat_model_stream events."""
from unittest.mock import patch, AsyncMock

import pytest
from langchain_core.messages import AIMessageChunk

from src.streaming.a2ui_partial_handler import A2uiPartialHandler


class TestA2uiPartialHandler:
    """Unit tests: feed canned AIMessageChunks, observe dispatched events."""

    _PATCH_TARGET = "src.streaming.a2ui_partial_handler.adispatch_custom_event"

    @staticmethod
    def _handler():
        # Fresh handler bound to the surface-rendering tool under test.
        return A2uiPartialHandler(tool_name="render_a2ui_surface")

    @staticmethod
    def _chunk(pieces):
        # Wrap raw tool_call_chunk dicts in an empty-content message chunk.
        return AIMessageChunk(content="", tool_call_chunks=pieces)

    @pytest.mark.asyncio
    async def test_dispatches_event_when_chunk_grows_args(self):
        handler = self._handler()
        with patch(self._PATCH_TARGET, new=AsyncMock()) as dispatch:
            piece = {"id": "tc-1", "name": "render_a2ui_surface", "args": "{\"envelopes\":[", "index": 0}
            await handler.on_chat_model_stream(self._chunk([piece]), run_id="r1")
            dispatch.assert_awaited_once_with(
                "a2ui-partial",
                {"tool_call_id": "tc-1", "args_so_far": "{\"envelopes\":["},
            )

    @pytest.mark.asyncio
    async def test_concatenates_args_across_chunks_same_tool_call_id(self):
        handler = self._handler()
        with patch(self._PATCH_TARGET, new=AsyncMock()) as dispatch:
            for delta in ("{", "\"x\":1}"):
                piece = {"id": "tc-1", "name": "render_a2ui_surface", "args": delta, "index": 0}
                await handler.on_chat_model_stream(self._chunk([piece]), run_id="r1")
            # Second dispatch carries the cumulative string.
            assert dispatch.await_count == 2
            payloads = [call.args[1]["args_so_far"] for call in dispatch.await_args_list]
            assert payloads[0] == "{"
            assert payloads[1] == "{\"x\":1}"

    @pytest.mark.asyncio
    async def test_ignores_chunks_for_unrelated_tools(self):
        handler = self._handler()
        with patch(self._PATCH_TARGET, new=AsyncMock()) as dispatch:
            unrelated = {"id": "tc-x", "name": "search_documents", "args": "x", "index": 0}
            await handler.on_chat_model_stream(self._chunk([unrelated]), run_id="r1")
            dispatch.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_no_dispatch_when_args_did_not_grow(self):
        handler = self._handler()
        with patch(self._PATCH_TARGET, new=AsyncMock()) as dispatch:
            # First chunk grows the buffer; second chunk has empty args delta
            # (a no-op chunk from the model) and must not re-dispatch.
            for delta in ("{", ""):
                piece = {"id": "tc-1", "name": "render_a2ui_surface", "args": delta, "index": 0}
                await handler.on_chat_model_stream(self._chunk([piece]), run_id="r1")
            dispatch.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_per_tool_call_id_state_isolation(self):
        handler = self._handler()
        with patch(self._PATCH_TARGET, new=AsyncMock()) as dispatch:
            pieces = [
                {"id": "tc-A", "name": "render_a2ui_surface", "args": "{", "index": 0},
                {"id": "tc-B", "name": "render_a2ui_surface", "args": "[", "index": 1},
            ]
            await handler.on_chat_model_stream(self._chunk(pieces), run_id="r1")
            assert dispatch.await_count == 2
            seen_ids = {call.args[1]["tool_call_id"] for call in dispatch.await_args_list}
            assert seen_ids == {"tc-A", "tc-B"}
52 changes: 52 additions & 0 deletions examples/chat/python/tests/test_streaming_smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Integration smoke: the generate node, when invoked with a canned chat-model
stream, dispatches a2ui-partial events and the final state has the
single-bubble shape from PR #255."""
import json
from unittest.mock import patch, AsyncMock

import pytest
from langchain_core.messages import AIMessageChunk

from src.streaming.a2ui_partial_handler import A2uiPartialHandler


def _make_canned_stream() -> list[AIMessageChunk]:
    """Five chunks of growing args for one tool_call to render_a2ui_surface."""
    # The argument string for tool call tc-1, split into streaming deltas.
    arg_fragments = [
        '{"envelopes":[',
        '{"surfaceUpdate":{"surfaceId":"s","components":[{"id":"root","type":"text","props":{}}]}},',
        '{"beginRendering":{"surfaceId":"s","root":"root"}},',
        '{"dataModelUpdate":{"surfaceId":"s","contents":[{"key":"text","valueString":"hi"}]}}',
        "]}",
    ]
    return [
        AIMessageChunk(
            content="",
            tool_call_chunks=[{
                "id": "tc-1",
                "name": "render_a2ui_surface",
                "index": 0,
                "args": fragment,
            }],
        )
        for fragment in arg_fragments
    ]


@pytest.mark.asyncio
async def test_handler_dispatches_per_chunk():
    """At least 3 a2ui-partial events fire as the canned stream advances."""
    handler = A2uiPartialHandler(tool_name="render_a2ui_surface")
    target = "src.streaming.a2ui_partial_handler.adispatch_custom_event"
    with patch(target, new=AsyncMock()) as dispatch:
        for message_chunk in _make_canned_stream():
            await handler.on_chat_model_stream(message_chunk, run_id="r1")
        assert dispatch.await_count >= 3
        # Last cumulative string is the full envelope JSON.
        final_payload = dispatch.await_args_list[-1].args[1]
        assert final_payload["tool_call_id"] == "tc-1"
        envelope_doc = json.loads(final_payload["args_so_far"])
        assert "envelopes" in envelope_doc
        assert len(envelope_doc["envelopes"]) == 3
Loading