From d9260f305008a49ec0304d873c59f21c71891025 Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Sun, 14 Jun 2026 14:28:49 +0600 Subject: [PATCH 1/3] Add unit tests for AudioAgent, LiveAgent, and realtime events - Introduced comprehensive unit tests for the AudioAgent, ensuring proper compilation and tool integration. - Added tests for the LiveAgent's duplex loop, tool invocation, and session management. - Implemented tests for various realtime events, including audio delta, transcripts, and tool calls. - Created a test suite for the GeminiLiveClient, focusing on message normalization and client lifecycle. - Established tests for the LiveInputQueue, verifying audio and text frame handling. - Enhanced the realtime events taxonomy with new categories and content types. --- agentflow/core/graph/compiled_graph.py | 100 +++- agentflow/core/realtime/__init__.py | 51 ++ agentflow/core/realtime/base.py | 217 ++++++++ agentflow/core/realtime/live_agent.py | 462 ++++++++++++++++++ agentflow/core/realtime/providers/__init__.py | 0 .../core/realtime/providers/gemini_live.py | 292 +++++++++++ agentflow/core/realtime/queue.py | 98 ++++ agentflow/prebuilt/agent/__init__.py | 2 + agentflow/prebuilt/agent/audio.py | 125 +++++ agentflow/runtime/publisher/events.py | 2 + .../2026-06-14-realtime-audio-agent-design.md | 312 ++++++++++++ examples/realtime/README.md | 79 +++ examples/realtime/__init__.py | 0 examples/realtime/agentflow.json | 8 + examples/realtime/audio_agent_file.py | 106 ++++ examples/realtime/audio_agent_mic.py | 104 ++++ examples/realtime/graph.py | 43 ++ tests/publisher/test_events.py | 11 +- tests/realtime/__init__.py | 0 tests/realtime/test_arealtime.py | 114 +++++ tests/realtime/test_audio_agent.py | 54 ++ tests/realtime/test_base.py | 108 ++++ tests/realtime/test_gemini_live.py | 390 +++++++++++++++ tests/realtime/test_live_agent.py | 456 +++++++++++++++++ tests/realtime/test_package_exports.py | 38 ++ tests/realtime/test_queue.py | 71 +++ 
tests/realtime/test_realtime_events_enum.py | 19 + 27 files changed, 3257 insertions(+), 5 deletions(-) create mode 100644 agentflow/core/realtime/__init__.py create mode 100644 agentflow/core/realtime/base.py create mode 100644 agentflow/core/realtime/live_agent.py create mode 100644 agentflow/core/realtime/providers/__init__.py create mode 100644 agentflow/core/realtime/providers/gemini_live.py create mode 100644 agentflow/core/realtime/queue.py create mode 100644 agentflow/prebuilt/agent/audio.py create mode 100644 docs/superpowers/specs/2026-06-14-realtime-audio-agent-design.md create mode 100644 examples/realtime/README.md create mode 100644 examples/realtime/__init__.py create mode 100644 examples/realtime/agentflow.json create mode 100644 examples/realtime/audio_agent_file.py create mode 100644 examples/realtime/audio_agent_mic.py create mode 100644 examples/realtime/graph.py create mode 100644 tests/realtime/__init__.py create mode 100644 tests/realtime/test_arealtime.py create mode 100644 tests/realtime/test_audio_agent.py create mode 100644 tests/realtime/test_base.py create mode 100644 tests/realtime/test_gemini_live.py create mode 100644 tests/realtime/test_live_agent.py create mode 100644 tests/realtime/test_package_exports.py create mode 100644 tests/realtime/test_queue.py create mode 100644 tests/realtime/test_realtime_events_enum.py diff --git a/agentflow/core/graph/compiled_graph.py b/agentflow/core/graph/compiled_graph.py index 437146f1..9786def8 100644 --- a/agentflow/core/graph/compiled_graph.py +++ b/agentflow/core/graph/compiled_graph.py @@ -20,6 +20,7 @@ from agentflow.storage.checkpointer.base_checkpointer import BaseCheckpointer from agentflow.storage.store.base_store import BaseStore from agentflow.utils import ( + CallbackManager, ResponseGranularity, ) from agentflow.utils.background_task_manager import BackgroundTaskManager @@ -275,6 +276,7 @@ async def ainvoke( Returns: Response dict based on granularity """ + 
self._guard_not_realtime() cfg = self._prepare_config(config, is_stream=False) return await self._invoke_handler.invoke( @@ -465,7 +467,7 @@ async def astream( Yields: Message objects with incremental content """ - + self._guard_not_realtime() cfg = self._prepare_config(config, is_stream=True) async for chunk in self._stream_handler.stream( @@ -527,6 +529,102 @@ def attach_remote_tools( node_name, ) + # ------------------------------------------------------------------ # + # Realtime runtime (audio-to-audio). A separate runtime from the + # super-step invoke/stream loop: the live agent owns the turn loop. + # ------------------------------------------------------------------ # + def _find_live_nodes(self) -> list[tuple[str, Node]]: + from agentflow.core.realtime.live_agent import LiveAgent + + return [ + (name, node) + for name, node in self._state_graph.nodes.items() + if isinstance(node.func, LiveAgent) + ] + + def _guard_not_realtime(self) -> None: + """Forcing rule: a graph containing a LiveAgent must use arealtime().""" + if self._find_live_nodes(): + raise RuntimeError( + "This graph contains a LiveAgent; use .arealtime() / .realtime() instead of " + "invoke/ainvoke/stream/astream." + ) + + async def arealtime( + self, + input_queue: Any, + config: dict[str, Any] | None = None, + state: AgentState | None = None, + ) -> AsyncIterator[Any]: + """Run the graph's realtime (audio) session, yielding normalized RealtimeEvents. + + Forcing rule: the graph must contain exactly one LiveAgent (the root controller); + ordinary turn-based graphs must use invoke/stream. + """ + live = self._find_live_nodes() + if not live: + raise RuntimeError( + "arealtime() requires a graph rooted at a LiveAgent (e.g. AudioAgent); " + "this graph has none. Use invoke/stream for turn-based graphs." + ) + if len(live) > 1: + raise RuntimeError( + "Only one LiveAgent is allowed per realtime run in v1 " + f"(found {len(live)}: {[name for name, _ in live]})." 
+ ) + + name, node = live[0] + agent = node.func + agent._node_name = name + cfg = self._prepare_config(config, is_stream=True) + callback_manager = InjectQ.get_instance().try_get(CallbackManager) + context_manager = self._state_graph._context_manager + run_state = state if state is not None else (self._state or AgentState()) + + async for event in agent.arun( + input_queue, + cfg, + run_state, + checkpointer=self._checkpointer, + callback_manager=callback_manager, + context_manager=context_manager, + ): + yield event + + def realtime( + self, + input_queue: Any, + config: dict[str, Any] | None = None, + state: AgentState | None = None, + ) -> Generator[Any]: + """Synchronous wrapper over :meth:`arealtime` for non-async consumers. + + Must be called from a thread with no running event loop; from inside an async + context (FastAPI handler, Jupyter), use :meth:`arealtime` directly. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + pass # no running loop: safe to drive a private one below + else: + raise RuntimeError( + "realtime() (sync) cannot be called from a running event loop; " + "await arealtime() instead." + ) + + agen = self.arealtime(input_queue, config, state) + loop = asyncio.new_event_loop() + try: + while True: + try: + yield loop.run_until_complete(agen.__anext__()) + except StopAsyncIteration: + break + finally: + with contextlib.suppress(Exception): + loop.run_until_complete(agen.aclose()) + loop.close() + async def aclose(self) -> dict[str, Any]: # noqa: PLR0915 """ Close the graph and release all resources gracefully. diff --git a/agentflow/core/realtime/__init__.py b/agentflow/core/realtime/__init__.py new file mode 100644 index 00000000..e8bf65b2 --- /dev/null +++ b/agentflow/core/realtime/__init__.py @@ -0,0 +1,51 @@ +"""Realtime (audio-to-audio) runtime primitives. 
+ +Provider-neutral contracts (:data:`RealtimeEvent`, :class:`RealtimeConfig`, +:class:`RealtimeClient`), the upstream :class:`LiveInputQueue`, and the Gemini Live +provider client. Provider SDK imports are lazy, so importing this package never pulls +the ``realtime`` optional dependency. +""" + +from .base import ( + AgentChangedEvent, + AudioDeltaEvent, + ErrorEvent, + GoAwayEvent, + InputTranscriptEvent, + InterruptedEvent, + OutputTranscriptEvent, + RealtimeClient, + RealtimeConfig, + RealtimeEvent, + SessionUpdateEvent, + ToolCallEvent, + ToolResultEvent, + TurnCompleteEvent, + VADConfig, +) +from .providers.gemini_live import GeminiLiveClient, normalize_message +from .queue import LiveInput, LiveInputKind, LiveInputQueue + + +__all__ = [ + "AgentChangedEvent", + "AudioDeltaEvent", + "ErrorEvent", + "GeminiLiveClient", + "GoAwayEvent", + "InputTranscriptEvent", + "InterruptedEvent", + "LiveInput", + "LiveInputKind", + "LiveInputQueue", + "OutputTranscriptEvent", + "RealtimeClient", + "RealtimeConfig", + "RealtimeEvent", + "SessionUpdateEvent", + "ToolCallEvent", + "ToolResultEvent", + "TurnCompleteEvent", + "VADConfig", + "normalize_message", +] diff --git a/agentflow/core/realtime/base.py b/agentflow/core/realtime/base.py new file mode 100644 index 00000000..859a7152 --- /dev/null +++ b/agentflow/core/realtime/base.py @@ -0,0 +1,217 @@ +"""Provider-neutral contracts for realtime (audio-to-audio) sessions. + +These types are the seam between Agentflow and any realtime provider (Gemini Live +first, OpenAI Realtime later). Nothing here imports a provider SDK; provider clients +live under ``agentflow.core.realtime.providers`` and normalize their wire messages +into the :data:`RealtimeEvent` union defined below. +""" + +from typing import Annotated, Any, Literal, Protocol, Union, runtime_checkable + +from pydantic import BaseModel, Field, field_validator + + +# Audio format facts for Gemini Live: input PCM16 mono @ 16kHz, output PCM16 @ 24kHz. 
+INPUT_SAMPLE_RATE = 16000 +OUTPUT_SAMPLE_RATE = 24000 + + +# --------------------------------------------------------------------------- # +# RealtimeEvent: the normalized event every downstream consumer reads. +# Discriminated union keyed on ``type`` (mirrors core.state ContentBlock). +# --------------------------------------------------------------------------- # +class AudioDeltaEvent(BaseModel): + """A chunk of model audio output (PCM16).""" + + type: Literal["audio_delta"] = "audio_delta" + data: bytes + sample_rate: int = OUTPUT_SAMPLE_RATE + + +class InputTranscriptEvent(BaseModel): + """Transcript of the user's speech (from the provider's input transcription).""" + + type: Literal["input_transcript"] = "input_transcript" + text: str + finished: bool = False + + +class OutputTranscriptEvent(BaseModel): + """Transcript of the model's speech (from the provider's output transcription).""" + + type: Literal["output_transcript"] = "output_transcript" + text: str + finished: bool = False + + +class ToolCallEvent(BaseModel): + """The provider is requesting a tool invocation.""" + + type: Literal["tool_call"] = "tool_call" + id: str + name: str + args: dict[str, Any] = Field(default_factory=dict) + + +class ToolResultEvent(BaseModel): + """A tool finished executing (emitted for observability after the result is sent back).""" + + type: Literal["tool_result"] = "tool_result" + id: str + result: Any = None + + +class TurnCompleteEvent(BaseModel): + """The model finished generating a turn.""" + + type: Literal["turn_complete"] = "turn_complete" + + +class InterruptedEvent(BaseModel): + """Barge-in: the user spoke over the model; the client should flush playback.""" + + type: Literal["interrupted"] = "interrupted" + + +class SessionUpdateEvent(BaseModel): + """The provider issued a session-resumption handle.""" + + type: Literal["session_update"] = "session_update" + resumption_handle: str | None = None + + +class GoAwayEvent(BaseModel): + """The provider will close 
the socket soon; reconnect with the resumption handle.""" + + type: Literal["go_away"] = "go_away" + # Provider duration string (e.g. Gemini "5s"); passed through verbatim. + time_left: str | None = None + + +class AgentChangedEvent(BaseModel): + """The active agent/author changed (future multi-agent persona swap).""" + + type: Literal["agent_changed"] = "agent_changed" + author: str + + +class ErrorEvent(BaseModel): + """A normalized provider error. Fatal errors close the session; transient ones continue.""" + + type: Literal["error"] = "error" + code: str | None = None + message: str + fatal: bool = False + + +RealtimeEvent = Annotated[ + Union[ + AudioDeltaEvent, + InputTranscriptEvent, + OutputTranscriptEvent, + ToolCallEvent, + ToolResultEvent, + TurnCompleteEvent, + InterruptedEvent, + SessionUpdateEvent, + GoAwayEvent, + AgentChangedEvent, + ErrorEvent, + ], + Field(discriminator="type"), +] + + +# --------------------------------------------------------------------------- # +# RealtimeConfig: per-session value object, provider-neutral. +# --------------------------------------------------------------------------- # +ResponseModality = Literal["AUDIO", "TEXT"] + + +class VADConfig(BaseModel): + """Voice-activity-detection settings. Disable for push-to-talk (manual activity).""" + + enabled: bool = True + # Provider-neutral sensitivity hint; mapped per provider. None = provider default. + start_sensitivity: str | None = None + end_sensitivity: str | None = None + prefix_padding_ms: int | None = None + silence_duration_ms: int | None = None + + +class RealtimeConfig(BaseModel): + """Per-session configuration handed to a :class:`RealtimeClient`. + + Gemini Live permits exactly one response modality per session; ``response_modalities`` + is validated to a single entry. 
+ """ + + model: str + response_modalities: list[ResponseModality] = Field(default_factory=lambda: ["AUDIO"]) + voice: str | None = None + system_instruction: str | None = None + input_audio_transcription: bool = True + output_audio_transcription: bool = True + vad: VADConfig = Field(default_factory=VADConfig) + context_window_compression: bool = False + session_resumption: bool = True + tools: list[Any] | None = None + tools_tags: list[str] | None = None + + @field_validator("response_modalities") + @classmethod + def _exactly_one_modality(cls, value: list[ResponseModality]) -> list[ResponseModality]: + if len(value) != 1: + raise ValueError( + "response_modalities must contain exactly one modality per session " + f"(got {value!r}); a realtime session is single-modality." + ) + return value + + +# --------------------------------------------------------------------------- # +# RealtimeClient: provider Protocol. One implementation per provider. +# --------------------------------------------------------------------------- # +@runtime_checkable +class RealtimeClient(Protocol): + """Protocol every provider client implements. + + Owns a single provider WebSocket for the lifetime of a session. ``receive()`` yields + normalized :data:`RealtimeEvent`s; the send methods push input upstream. + """ + + async def connect(self, config: RealtimeConfig, resume_handle: str | None = None) -> None: + """Open the provider socket for ``config``, optionally resuming ``resume_handle``.""" + ... + + async def send_audio(self, pcm: bytes, sample_rate: int) -> None: + """Send a chunk of input audio (PCM16).""" + ... + + async def send_text(self, text: str) -> None: + """Send a text turn into the live session.""" + ... + + async def send_activity_start(self) -> None: + """Manual-VAD / push-to-talk: mark the start of user activity.""" + ... + + async def send_activity_end(self) -> None: + """Manual-VAD / push-to-talk: mark the end of user activity.""" + ... 
+ + async def send_tool_response(self, call_id: str, name: str, result: Any) -> None: + """Return a tool result to the model for ``call_id``.""" + ... + + async def reseed_history(self, messages: list[Any]) -> None: + """Seed an existing conversation history into a fresh session (cross-session resume).""" + ... + + def receive(self): # -> AsyncIterator[RealtimeEvent] + """Async-iterate normalized events from the provider.""" + ... + + async def close(self) -> None: + """Close the provider socket. Must be safe to call more than once.""" + ... diff --git a/agentflow/core/realtime/live_agent.py b/agentflow/core/realtime/live_agent.py new file mode 100644 index 00000000..d0040512 --- /dev/null +++ b/agentflow/core/realtime/live_agent.py @@ -0,0 +1,462 @@ +"""LiveAgent -- the realtime (audio-to-audio) node and root session controller. + +``LiveAgent`` subclasses :class:`BaseAgent` and reuses the skills/memory builder mixins, +but it deliberately **excludes** the text turn loop (``AgentExecutionMixin``): realtime +inverts control (the provider owns the turn loop), so it writes its own duplex loop. + +It is entered by ``CompiledGraph.arealtime`` (Phase 3) via :meth:`arun`, which: + +1. opens one provider socket (the spine, held for the whole session), +2. runs a pump task (queue -> provider) concurrently with a receive loop, +3. dispatches tool calls through the existing :class:`ToolNode` (callbacks + publisher + events fire inside ToolNode, so transparency is identical to text mode), +4. persists finished transcripts as ``Message``s (no audio at rest), and +5. transparently reconnects on ``go_away``/drop using the cached resumption handle. + +Calling the agent through the turn-based engine (``execute``) raises -- the forcing rule. 
+""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import AsyncIterator, Callable +from typing import TYPE_CHECKING, Any + +from agentflow.core.graph.agent_internal.memory import AgentMemoryMixin +from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin +from agentflow.core.graph.base_agent import BaseAgent +from agentflow.core.graph.tool_node import ToolNode +from agentflow.core.llm import detect_provider +from agentflow.core.realtime.base import RealtimeClient, RealtimeConfig, ToolResultEvent +from agentflow.core.realtime.providers.gemini_live import GeminiLiveClient +from agentflow.core.state import AgentState, Message, TextBlock, add_messages +from agentflow.runtime.publisher.events import ContentType, Event, EventModel, EventType +from agentflow.runtime.publisher.publish import publish_event +from agentflow.utils import CallbackManager + + +if TYPE_CHECKING: + from agentflow.core.realtime.base import RealtimeEvent + from agentflow.core.realtime.queue import LiveInputQueue + from agentflow.core.state import BaseContextManager + from agentflow.storage.checkpointer import BaseCheckpointer + +logger = logging.getLogger(__name__) + + +class LiveAgent(AgentSkillsMixin, AgentMemoryMixin, BaseAgent): + """Realtime audio agent node. 
Run via :meth:`arun` (or ``CompiledGraph.arealtime``).""" + + def __init__( + self, + model: str, + *, + realtime_config: RealtimeConfig | None = None, + system_prompt: list[dict[str, Any]] | None = None, + tool_node: str | ToolNode | None = None, + skills: Any | None = None, + memory: Any | None = None, + realtime_client_factory: Callable[[], RealtimeClient] | None = None, + **kwargs: Any, + ) -> None: + api_key: str | None = kwargs.pop("api_key", None) + use_vertex_ai: bool = kwargs.pop("use_vertex_ai", False) + + provider = detect_provider(model, use_vertex_ai) + if provider != "google": + raise ValueError( + "LiveAgent v1 supports only Gemini Live (google provider); " + f"resolved provider '{provider}' for model '{model}'." + ) + + super().__init__( + model=model, + system_prompt=system_prompt or [], + tool_node=tool_node, + **kwargs, + ) + + self.provider = "google" + self.use_vertex_ai = use_vertex_ai + self.realtime_config = realtime_config or RealtimeConfig(model=model) + + # Tool wiring (we don't use AgentExecutionMixin._setup_tools). + self._tool_node: ToolNode | None = tool_node if isinstance(tool_node, ToolNode) else None + self.tool_node_name: str | None = tool_node if isinstance(tool_node, str) else None + + # One client per *connection*; the factory lets reconnects get a fresh socket + # and lets tests inject a fake provider. + self._client_factory: Callable[[], RealtimeClient] = realtime_client_factory or ( + lambda: GeminiLiveClient(api_key=api_key, use_vertex_ai=use_vertex_ai) + ) + self._active_client: RealtimeClient | None = None + self._resume_handle: str | None = None + # Serializes upstream sends against reconnect (close+connect) so the pump never + # sends on a socket being torn down, and always picks up the reconnected client. + self._send_lock = asyncio.Lock() + + # Builder mixins (no-op when their config is None). 
+ self._setup_memory(memory) + self._setup_skills(skills) + + # ------------------------------------------------------------------ # + # Forcing rule: a live agent is not a turn-based node. + # ------------------------------------------------------------------ # + async def execute(self, state: AgentState, config: dict[str, Any]) -> Any: + raise RuntimeError( + "LiveAgent runs via CompiledGraph.arealtime(); it is not a turn-based node. " + "Use .arealtime() (or the AudioAgent prebuilt), not invoke/stream." + ) + + async def _call_llm( + self, messages: list[dict[str, Any]], tools: list | None = None, **kwargs: Any + ) -> Any: + # Realtime never makes a discrete turn-based LLM call; the provider owns the loop. + raise RuntimeError( + "LiveAgent has no discrete LLM call; audio turns are driven by the realtime socket." + ) + + def _resolve_tool_node(self) -> ToolNode | None: + return self._tool_node + + # ------------------------------------------------------------------ # + # The duplex realtime loop. 
+ # ------------------------------------------------------------------ # + async def arun( + self, + input_queue: LiveInputQueue, + config: dict[str, Any], + state: AgentState | None = None, + *, + checkpointer: BaseCheckpointer | None = None, + callback_manager: CallbackManager | None = None, + context_manager: BaseContextManager | None = None, + ) -> AsyncIterator[RealtimeEvent]: + """Open the session and yield normalized events until the queue/session closes.""" + state = state if state is not None else AgentState() + if callback_manager is None: + callback_manager = CallbackManager() + rt = self._session_realtime_config(config) + rt = await self._resolve_session_tools(rt) + + handle = await self._load_resume_handle(config, checkpointer) + client = self._client_factory() + await client.connect(rt, resume_handle=handle) + self._active_client = client + await self._maybe_reseed(config, checkpointer, context_manager) + + # Closing the input queue ends the session: the pump sets this when it drains the + # close sentinel, and the receive loop stops once the provider goes idle. + stop_event = asyncio.Event() + pump_task = asyncio.create_task(self._pump(input_queue, stop_event)) + try: + while True: + reconnect = False + forced = False # go_away: reconnect even after input closed, to finish the turn + try: + async for event in self._receive_until_stop(self._active_client, stop_event): + for out in await self._handle_event( + event, config, state, checkpointer, callback_manager + ): + yield out + if event.type == "go_away": + reconnect = True + forced = True + break + if event.type == "error" and getattr(event, "fatal", False): + return + except Exception: + # Transient drop: only resume if input is still open (avoid an + # infinite reconnect storm once the session is shutting down). 
+ logger.warning("realtime receive loop error; attempting resume", exc_info=True) + reconnect = True + + if reconnect and rt.session_resumption and (forced or not stop_event.is_set()): + await self._reconnect(rt) + continue + break + finally: + pump_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await pump_task + await self._active_client.close() + + def _session_realtime_config(self, config: dict[str, Any]) -> RealtimeConfig: + """Merge per-session overrides (``config["realtime"]``) over the agent's base config. + + Lets a caller (e.g. the API init frame) pick model/voice/modalities/vad per session + without rebuilding the agent. Unknown keys are ignored; the result is re-validated. + """ + overrides = (config or {}).get("realtime") or {} + if not overrides: + return self.realtime_config + base = self.realtime_config.model_dump() + for key, value in overrides.items(): + if value is not None and key in base: + base[key] = value + return RealtimeConfig.model_validate(base) + + async def _resolve_session_tools(self, rt: RealtimeConfig) -> RealtimeConfig: + """Advertise the agent's ToolNode tools to the provider so the model can call them. + + No-op when tools were set explicitly on the config (the caller wins) or when there + is no ToolNode. Tools are emitted as provider-neutral OpenAI-style dicts (the same + shape the turn-based path uses); the provider client converts them. ``rt.tools_tags`` + filters which tools are advertised. 
+ """ + if rt.tools is not None: + return rt + tool_node = self._resolve_tool_node() + if tool_node is None: + return rt + tags = set(rt.tools_tags) if rt.tools_tags else None + schemas = await tool_node.all_tools(tags=tags) + if not schemas: + return rt + return rt.model_copy(update={"tools": schemas}) + + async def _receive_until_stop( + self, client: RealtimeClient, stop_event: asyncio.Event + ) -> AsyncIterator[RealtimeEvent]: + """Yield provider events, but return when ``stop_event`` fires *and* the provider + is idle. Already-available events are always drained first (a closed input queue + must not truncate the model's in-flight response).""" + receiver = client.receive().__aiter__() + stop_task = asyncio.ensure_future(stop_event.wait()) + try: + while True: + next_task = asyncio.ensure_future(receiver.__anext__()) + done, _ = await asyncio.wait( + {next_task, stop_task}, return_when=asyncio.FIRST_COMPLETED + ) + if next_task in done: + try: + yield next_task.result() + continue + except StopAsyncIteration: + return + # Provider idle and stop requested: abandon the pending receive and end. + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await next_task + return + finally: + stop_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await stop_task + + # ------------------------------------------------------------------ # + # Pump: upstream queue -> provider socket. + # ------------------------------------------------------------------ # + async def _pump( + self, input_queue: LiveInputQueue, stop_event: asyncio.Event | None = None + ) -> None: + try: + async for item in input_queue: + # Hold the lock across the send so a concurrent reconnect can't swap the + # socket mid-send; re-read the client *inside* the lock to use the live one. 
+ async with self._send_lock: + client = self._active_client + if client is None: + continue + try: + if item.kind == "audio" and item.data is not None: + await client.send_audio(item.data, item.sample_rate) + elif item.kind == "text" and item.text is not None: + await client.send_text(item.text) + elif item.kind == "activity_start": + await client.send_activity_start() + elif item.kind == "activity_end": + await client.send_activity_end() + except Exception: + logger.warning( + "realtime pump send failed for %s frame", item.kind, exc_info=True + ) + finally: + # Input is exhausted (queue closed): let the receive loop end once idle. + if stop_event is not None: + stop_event.set() + + # ------------------------------------------------------------------ # + # Per-event handling. Returns the caller-facing events to yield. + # ------------------------------------------------------------------ # + async def _handle_event( + self, + event: RealtimeEvent, + config: dict[str, Any], + state: AgentState, + checkpointer: BaseCheckpointer | None, + callback_manager: CallbackManager, + ) -> list[RealtimeEvent]: + kind = event.type + + if kind == "tool_call": + return await self._run_tool(event, config, state, callback_manager) + + if kind == "input_transcript" and event.finished: + await self._persist_transcript(event.text, "user", config, state, checkpointer) + self._publish_realtime( + EventType.RESULT, config, ContentType.TRANSCRIPT, "input_transcript" + ) + elif kind == "output_transcript" and event.finished: + await self._persist_transcript(event.text, "assistant", config, state, checkpointer) + self._publish_realtime( + EventType.RESULT, config, ContentType.TRANSCRIPT, "output_transcript" + ) + elif kind == "interrupted": + self._publish_realtime(EventType.INTERRUPTED, config, ContentType.AUDIO, "barge_in") + elif kind == "go_away": + self._publish_realtime(EventType.UPDATE, config, ContentType.UPDATE, "go_away") + elif kind == "session_update": + self._resume_handle = 
event.resumption_handle + await self._persist_handle(config, checkpointer) + self._publish_realtime(EventType.UPDATE, config, ContentType.UPDATE, "session_resumed") + + return [event] + + async def _run_tool( + self, + event: RealtimeEvent, + config: dict[str, Any], + state: AgentState, + callback_manager: CallbackManager, + ) -> list[RealtimeEvent]: + tool_node = self._resolve_tool_node() + if tool_node is None: + result: Any = {"error": f"no tools registered for '{event.name}'"} + else: + invoked = await tool_node.invoke( + event.name, + event.args, + event.id, + config, + state, + callback_manager=callback_manager, + ) + result = self._extract_tool_result(invoked) + + # Socket stays open; feed the result back to the model. + await self._active_client.send_tool_response(event.id, event.name, result) + return [event, ToolResultEvent(id=event.id, result=result)] + + @staticmethod + def _extract_tool_result(invoked: Any) -> dict[str, Any]: + if isinstance(invoked, Message): + for block in invoked.content: + if getattr(block, "type", None) == "tool_result": + return {"result": getattr(block, "output", None)} + return {"result": None} + if isinstance(invoked, dict): + return invoked + return {"result": invoked} + + # ------------------------------------------------------------------ # + # Transcript persistence (Message only; audio is never stored at rest). + # ------------------------------------------------------------------ # + async def _persist_transcript( + self, + text: str, + role: str, + config: dict[str, Any], + state: AgentState, + checkpointer: BaseCheckpointer | None, + ) -> None: + msg = Message(role=role, content=[TextBlock(text=text)], metadata={"modality": "audio"}) + state.context = add_messages(state.context, [msg]) + if checkpointer is not None: + await checkpointer.aput_messages(config, [msg]) + + # ------------------------------------------------------------------ # + # Resumption: within-session reconnect + cross-session reseed. 
+ # ------------------------------------------------------------------ # + async def _load_resume_handle( + self, config: dict[str, Any], checkpointer: BaseCheckpointer | None + ) -> str | None: + if checkpointer is None or not self.realtime_config.session_resumption: + return None + try: + thread = await checkpointer.aget_thread(config) + except Exception: + return None + if thread and thread.metadata: + handle = thread.metadata.get("resumption_handle") + self._resume_handle = handle + return handle + return None + + async def _persist_handle( + self, config: dict[str, Any], checkpointer: BaseCheckpointer | None + ) -> None: + if checkpointer is None: + return + from agentflow.utils.thread_info import ThreadInfo + + try: + thread = await checkpointer.aget_thread(config) + except Exception: + thread = None + metadata = dict(thread.metadata or {}) if thread else {} + metadata["resumption_handle"] = self._resume_handle + info = ThreadInfo( + thread_id=config.get("thread_id", ""), + user_id=config.get("user_id"), + metadata=metadata, + ) + await checkpointer.aput_thread(config, info) + + async def _maybe_reseed( + self, + config: dict[str, Any], + checkpointer: BaseCheckpointer | None, + context_manager: BaseContextManager | None, + ) -> None: + if checkpointer is None: + return + try: + history = await checkpointer.alist_messages(config) + except Exception: + history = None + if not history: + return + if context_manager is not None: + try: + trimmed = await context_manager.atrim_context(AgentState(context=list(history))) + history = trimmed.context + except Exception: + logger.warning("context compression failed during reseed; using raw history") + if history: + await self._active_client.reseed_history(list(history)) + + async def _reconnect(self, rt: RealtimeConfig) -> None: + async with self._send_lock: + old = self._active_client + with contextlib.suppress(Exception): + if old is not None: + await old.close() + client = self._client_factory() + await 
client.connect(rt, resume_handle=self._resume_handle) + self._active_client = client + + # ------------------------------------------------------------------ # + # Observability for events ToolNode doesn't already publish. + # ------------------------------------------------------------------ # + def _publish_realtime( + self, + event_type: EventType, + config: dict[str, Any], + content_type: ContentType, + lifecycle: str, + ) -> None: + publish_event( + EventModel.default( + config, + data={}, + content_type=[content_type], + event=Event.REALTIME, + event_type=event_type, + node_name=getattr(self, "_node_name", "LIVE"), + extra={"lifecycle": lifecycle, "modality": "audio"}, + ) + ) diff --git a/agentflow/core/realtime/providers/__init__.py b/agentflow/core/realtime/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentflow/core/realtime/providers/gemini_live.py b/agentflow/core/realtime/providers/gemini_live.py new file mode 100644 index 00000000..46ab168c --- /dev/null +++ b/agentflow/core/realtime/providers/gemini_live.py @@ -0,0 +1,292 @@ +"""Gemini Live provider client. + +Wraps ``client.aio.live.connect(...)`` (an async context manager yielding a live +session) behind the provider-neutral :class:`~agentflow.core.realtime.base.RealtimeClient` +protocol. ``normalize_message`` maps a google ``LiveServerMessage`` to the framework's +:data:`RealtimeEvent` union; it is duck-typed and imports no provider SDK, so it is unit +testable with lightweight stand-ins. 
+""" + +from __future__ import annotations + +import logging +import re +from collections.abc import AsyncIterator +from typing import Any +from uuid import uuid4 + +from agentflow.core.realtime.base import ( + INPUT_SAMPLE_RATE, + OUTPUT_SAMPLE_RATE, + AudioDeltaEvent, + GoAwayEvent, + InputTranscriptEvent, + InterruptedEvent, + OutputTranscriptEvent, + RealtimeConfig, + RealtimeEvent, + SessionUpdateEvent, + ToolCallEvent, + TurnCompleteEvent, +) + + +logger = logging.getLogger(__name__) + +_RATE_RE = re.compile(r"rate=(\d+)") + + +def _rate_from_mime(mime_type: str | None, default: int) -> int: + """Extract the sample rate from a ``audio/pcm;rate=24000`` mime string.""" + if not mime_type: + return default + match = _RATE_RE.search(mime_type) + return int(match.group(1)) if match else default + + +def normalize_message(message: Any) -> list[RealtimeEvent]: + """Map a google ``LiveServerMessage`` to zero or more normalized events. + + Reads attributes defensively (``getattr``) so it tolerates both the real SDK objects + and test stand-ins, and emits events in wire order within a single message. 
+ """ + events: list[RealtimeEvent] = [] + + content = getattr(message, "server_content", None) + if content is not None: + model_turn = getattr(content, "model_turn", None) + for part in getattr(model_turn, "parts", None) or []: + inline = getattr(part, "inline_data", None) + data = getattr(inline, "data", None) if inline is not None else None + if data: + rate = _rate_from_mime(getattr(inline, "mime_type", None), OUTPUT_SAMPLE_RATE) + events.append(AudioDeltaEvent(data=data, sample_rate=rate)) + + in_tx = getattr(content, "input_transcription", None) + if in_tx is not None and getattr(in_tx, "text", None) is not None: + events.append( + InputTranscriptEvent( + text=in_tx.text, finished=bool(getattr(in_tx, "finished", False)) + ) + ) + + out_tx = getattr(content, "output_transcription", None) + if out_tx is not None and getattr(out_tx, "text", None) is not None: + events.append( + OutputTranscriptEvent( + text=out_tx.text, finished=bool(getattr(out_tx, "finished", False)) + ) + ) + + if getattr(content, "interrupted", None): + events.append(InterruptedEvent()) + + if getattr(content, "generation_complete", None) or getattr(content, "turn_complete", None): + events.append(TurnCompleteEvent()) + + tool_call = getattr(message, "tool_call", None) + if tool_call is not None: + for fc in getattr(tool_call, "function_calls", None) or []: + events.append( + ToolCallEvent( + id=getattr(fc, "id", None) or uuid4().hex, + name=getattr(fc, "name", "") or "", + args=getattr(fc, "args", None) or {}, + ) + ) + + update = getattr(message, "session_resumption_update", None) + if update is not None: + events.append(SessionUpdateEvent(resumption_handle=getattr(update, "new_handle", None))) + + go_away = getattr(message, "go_away", None) + if go_away is not None: + events.append(GoAwayEvent(time_left=getattr(go_away, "time_left", None))) + + return events + + +class GeminiLiveClient: + """``RealtimeClient`` implementation backed by the Gemini Live API. 
+ + ``connector`` is the seam for testing/overrides: a callable + ``(model=..., config=...) -> async context manager`` that yields a live session. + In production it defaults to ``genai.Client(...).aio.live.connect``. + """ + + def __init__( + self, + client: Any | None = None, + *, + connector: Any | None = None, + api_key: str | None = None, + use_vertex_ai: bool = False, + ) -> None: + self._client = client + self._connector = connector + self._api_key = api_key + self._use_vertex_ai = use_vertex_ai + self._config: RealtimeConfig | None = None + self._cm: Any | None = None + self._session: Any | None = None + + @property + def connected(self) -> bool: + return self._session is not None + + # --- lazy provider construction (guarded optional dependency) --------- # + @staticmethod + def _genai(): + try: + from google import genai + from google.genai import types + except ImportError as exc: # pragma: no cover - exercised only without the extra + raise ImportError( + "google-genai SDK is required for Gemini realtime. 
" + "Install it with: pip install 10xscale-agentflow[realtime]" + ) from exc + return genai, types + + def _ensure_client(self) -> Any: + if self._client is None: + genai, _ = self._genai() + self._client = genai.Client(api_key=self._api_key, vertexai=self._use_vertex_ai) + return self._client + + def _get_connector(self) -> Any: + if self._connector is not None: + return self._connector + return self._ensure_client().aio.live.connect + + def _build_connect_config( + self, config: RealtimeConfig, resume_handle: str | None = None + ) -> Any: + _, types = self._genai() + kwargs: dict[str, Any] = { + "response_modalities": [types.Modality(m) for m in config.response_modalities], + } + if config.system_instruction: + kwargs["system_instruction"] = config.system_instruction + if config.voice: + kwargs["speech_config"] = types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=config.voice) + ) + ) + if config.input_audio_transcription: + kwargs["input_audio_transcription"] = types.AudioTranscriptionConfig() + if config.output_audio_transcription: + kwargs["output_audio_transcription"] = types.AudioTranscriptionConfig() + if config.session_resumption or resume_handle: + kwargs["session_resumption"] = types.SessionResumptionConfig(handle=resume_handle) + if config.context_window_compression: + kwargs["context_window_compression"] = types.ContextWindowCompressionConfig( + sliding_window=types.SlidingWindow() + ) + if not config.vad.enabled: + kwargs["realtime_input_config"] = types.RealtimeInputConfig( + automatic_activity_detection=types.AutomaticActivityDetection(disabled=True) + ) + if config.tools: + kwargs["tools"] = self._to_provider_tools(types, config.tools) + return types.LiveConnectConfig(**kwargs) + + @staticmethod + def _to_provider_tools(types: Any, tools: list[Any]) -> list[Any]: + """Convert provider-neutral OpenAI-style tool dicts into Gemini tool objects. 
+ + OpenAI-format dicts (``{"type":"function","function":{...}}``) are collected into a + single ``types.Tool(function_declarations=[...])``. Anything else (raw callables, + already-built ``types.Tool``/``FunctionDeclaration``) passes through untouched. + """ + declarations: list[Any] = [] + passthrough: list[Any] = [] + for entry in tools: + if isinstance(entry, dict) and "function" in entry: + fn = entry["function"] + decl_kwargs: dict[str, Any] = { + "name": fn["name"], + "description": fn.get("description", ""), + } + if fn.get("parameters") is not None: + decl_kwargs["parameters_json_schema"] = fn["parameters"] + declarations.append(types.FunctionDeclaration(**decl_kwargs)) + else: + passthrough.append(entry) + result = list(passthrough) + if declarations: + result.append(types.Tool(function_declarations=declarations)) + return result + + # --- RealtimeClient protocol ----------------------------------------- # + async def connect(self, config: RealtimeConfig, resume_handle: str | None = None) -> None: + self._config = config + connector = self._get_connector() + live_config = self._build_connect_config(config, resume_handle=resume_handle) + self._cm = connector(model=config.model, config=live_config) + self._session = await self._cm.__aenter__() + + def _require_session(self) -> Any: + if self._session is None: + raise RuntimeError("GeminiLiveClient is not connected; call connect() first") + return self._session + + async def send_audio(self, pcm: bytes, sample_rate: int = INPUT_SAMPLE_RATE) -> None: + session = self._require_session() + _, types = self._genai() + await session.send_realtime_input( + audio=types.Blob(data=pcm, mime_type=f"audio/pcm;rate={sample_rate}") + ) + + async def send_text(self, text: str) -> None: + session = self._require_session() + await session.send_realtime_input(text=text) + + async def send_activity_start(self) -> None: + session = self._require_session() + _, types = self._genai() + await 
session.send_realtime_input(activity_start=types.ActivityStart()) + + async def send_activity_end(self) -> None: + session = self._require_session() + _, types = self._genai() + await session.send_realtime_input(activity_end=types.ActivityEnd()) + + async def send_tool_response(self, call_id: str, name: str, result: Any) -> None: + session = self._require_session() + _, types = self._genai() + await session.send_tool_response( + function_responses=[types.FunctionResponse(id=call_id, name=name, response=result)] + ) + + async def reseed_history(self, messages: list[Any]) -> None: + session = self._require_session() + _, types = self._genai() + turns = [] + for message in messages: + text = "".join( + getattr(block, "text", "") or "" for block in getattr(message, "content", []) or [] + ) + if not text: + continue + role = "model" if getattr(message, "role", "user") == "assistant" else "user" + turns.append(types.Content(role=role, parts=[types.Part.from_text(text=text)])) + if turns: + await session.send_client_content(turns=turns, turn_complete=True) + + async def receive(self) -> AsyncIterator[RealtimeEvent]: + session = self._require_session() + async for message in session.receive(): + for event in normalize_message(message): + yield event + + async def close(self) -> None: + cm = self._cm + if cm is None: + return + self._cm = None + self._session = None + try: + await cm.__aexit__(None, None, None) + except Exception: # pragma: no cover - best-effort teardown + logger.warning("Error while closing Gemini live session", exc_info=True) diff --git a/agentflow/core/realtime/queue.py b/agentflow/core/realtime/queue.py new file mode 100644 index 00000000..efd0fff0 --- /dev/null +++ b/agentflow/core/realtime/queue.py @@ -0,0 +1,98 @@ +"""Upstream input decoupler for realtime sessions. + +``LiveInputQueue`` wraps an ``asyncio.Queue`` of ``LiveInput`` transport frames. 
``put`` +is synchronous and non-blocking (``put_nowait``) so the input side keeps accepting audio +while the model is still generating -- the precondition for barge-in. A fresh queue is +created per session; it is the object an SDK user (or the API bridge) feeds. + +``LiveInput`` is a deliberately lightweight ``dataclass`` (not a ``Message`` and not a +pydantic model): these are ephemeral control frames headed for the provider socket, never +persisted and produced on the audio hot path (~50/sec). Conversation state and the +checkpointer are driven separately, from provider *transcripts* turned into ``Message``s +(see design section 7) -- not from this queue. +""" + +import asyncio +import logging +from collections.abc import AsyncIterator +from dataclasses import dataclass +from typing import Literal + +from agentflow.core.realtime.base import INPUT_SAMPLE_RATE + + +logger = logging.getLogger(__name__) + +LiveInputKind = Literal["audio", "text", "activity_start", "activity_end", "close"] + + +@dataclass(slots=True) +class LiveInput: + """A single upstream transport frame. Construct via ``LiveInputQueue.send_*``. + + ``kind`` discriminates the frame; only the fields relevant to that kind are set + (``data``/``sample_rate`` for audio, ``text`` for text, neither for control frames). + """ + + kind: LiveInputKind + data: bytes | None = None + text: str | None = None + sample_rate: int = INPUT_SAMPLE_RATE + + +class LiveInputQueue: + """A non-blocking, single-session input queue feeding the realtime pump task. + + Producers call the synchronous ``send_*`` / ``close`` methods from any context + (e.g. an audio callback). The pump task consumes via ``get`` / ``async for``. + Once closed, further ``put``s are dropped (logged at debug), never raised. 
+ """ + + def __init__(self, maxsize: int = 0) -> None: + self._queue: asyncio.Queue[LiveInput] = asyncio.Queue(maxsize=maxsize) + self._closed = False + + @property + def closed(self) -> bool: + return self._closed + + def _put(self, item: LiveInput) -> None: + if self._closed and item.kind != "close": + logger.debug("LiveInputQueue is closed; dropping %s frame", item.kind) + return + try: + self._queue.put_nowait(item) + except asyncio.QueueFull: + logger.warning("LiveInputQueue full; dropping %s frame", item.kind) + + def send_audio(self, data: bytes, sample_rate: int = INPUT_SAMPLE_RATE) -> None: + self._put(LiveInput(kind="audio", data=data, sample_rate=sample_rate)) + + def send_text(self, text: str) -> None: + self._put(LiveInput(kind="text", text=text)) + + def send_activity_start(self) -> None: + self._put(LiveInput(kind="activity_start")) + + def send_activity_end(self) -> None: + self._put(LiveInput(kind="activity_end")) + + def close(self) -> None: + """Signal end of input. Idempotent; enqueues a single ``close`` sentinel frame.""" + if self._closed: + return + self._put(LiveInput(kind="close")) + self._closed = True + + async def get(self) -> LiveInput: + return await self._queue.get() + + def get_nowait(self) -> LiveInput: + return self._queue.get_nowait() + + async def __aiter__(self) -> AsyncIterator[LiveInput]: + while True: + item = await self._queue.get() + if item.kind == "close": + return + yield item diff --git a/agentflow/prebuilt/agent/__init__.py b/agentflow/prebuilt/agent/__init__.py index ee652755..c54196c0 100644 --- a/agentflow/prebuilt/agent/__init__.py +++ b/agentflow/prebuilt/agent/__init__.py @@ -1,3 +1,4 @@ +from .audio import AudioAgent from .plan_act_reflect import PlanActReflectAgent from .rag import BaseReranker, CohereReranker, CrossEncoderReranker, RAGAgent from .react import ReactAgent @@ -7,6 +8,7 @@ __all__ = [ + "AudioAgent", "BaseReranker", "CohereReranker", "CrossEncoderReranker", diff --git 
a/agentflow/prebuilt/agent/audio.py b/agentflow/prebuilt/agent/audio.py new file mode 100644 index 00000000..5ea30d1d --- /dev/null +++ b/agentflow/prebuilt/agent/audio.py @@ -0,0 +1,125 @@ +"""AudioAgent -- prebuilt realtime (audio-to-audio) agent, React-style builder. + +Mirrors :class:`~agentflow.prebuilt.agent.react.ReactAgent`'s construction surface but +wraps a :class:`~agentflow.core.realtime.live_agent.LiveAgent` as the graph root. The +compiled graph is driven by ``CompiledGraph.arealtime`` (a separate runtime), not +``invoke``/``stream``. No sub-agents / handoff are wired in v1 (a handoff tool is just a +tool, so the door stays open). +""" + +from collections.abc import Callable, Iterable +from typing import Any + +from agentflow.core.graph.compiled_graph import CompiledGraph +from agentflow.core.graph.state_graph import StateGraph +from agentflow.core.graph.tool_node import ToolNode +from agentflow.core.realtime.base import RealtimeClient, RealtimeConfig +from agentflow.core.realtime.live_agent import LiveAgent +from agentflow.core.state.agent_state import AgentState +from agentflow.core.state.base_context import BaseContextManager +from agentflow.runtime.publisher.base_publisher import BasePublisher +from agentflow.storage.checkpointer.base_checkpointer import BaseCheckpointer +from agentflow.storage.media.storage.base import BaseMediaStore +from agentflow.storage.store.base_store import BaseStore +from agentflow.utils.callbacks import CallbackManager +from agentflow.utils.constants import END +from agentflow.utils.id_generator import BaseIDGenerator, DefaultIDGenerator + + +class AudioAgent[StateT: AgentState]: + """Build and compile a single realtime audio agent graph.""" + + def __init__( # noqa: PLR0913 + self, + model: str, + state: StateT | None = None, + context_manager: BaseContextManager[StateT] | None = None, + publisher: BasePublisher | list[BasePublisher] | None = None, + id_generator: BaseIDGenerator = DefaultIDGenerator(), + container: Any 
| None = None, + *, + realtime_config: RealtimeConfig | None = None, + system_prompt: list[dict[str, Any]] | None = None, + tools: Iterable[Callable] | None = None, + client: Any = None, + pass_user_info_to_mcp: bool = False, + skills: Any | None = None, + memory: Any | None = None, + realtime_client_factory: Callable[[], RealtimeClient] | None = None, + live_node_name: str = "LIVE", + **agent_kwargs: Any, + ) -> None: + self._state = state + self._context_manager = context_manager + self._publisher = publisher + self._id_generator = id_generator + self._container = container + self._live_node_name = live_node_name + + self._tool_node = self._build_tool_node( + tools=list(tools or []), + client=client, + pass_user_info_to_mcp=pass_user_info_to_mcp, + ) + + self._agent = LiveAgent( + model, + realtime_config=realtime_config, + system_prompt=system_prompt, + tool_node=self._tool_node, + skills=skills, + memory=memory, + realtime_client_factory=realtime_client_factory, + **agent_kwargs, + ) + self._graph: StateGraph[StateT] | None = None + + @staticmethod + def _build_tool_node( + *, tools: list[Callable], client: Any, pass_user_info_to_mcp: bool + ) -> ToolNode | None: + if not tools and client is None: + return None + return ToolNode(tools, client=client, pass_user_info_to_mcp=pass_user_info_to_mcp) + + def _create_graph(self) -> StateGraph[StateT]: + return StateGraph[StateT]( + state=self._state, + context_manager=self._context_manager, + publisher=self._publisher, + id_generator=self._id_generator, + container=self._container, + ) + + def _configure_graph(self) -> None: + self._graph = self._create_graph() + self._graph.add_node(self._live_node_name, self._agent) + self._graph.set_entry_point(self._live_node_name) + # The edge is never traversed in realtime (the live node owns the loop); it exists + # only so the graph is well-formed for compile(). 
+ self._graph.add_edge(self._live_node_name, END) + + def compile( + self, + checkpointer: BaseCheckpointer[StateT] | None = None, + store: BaseStore | None = None, + interrupt_before: list[str] | None = None, + interrupt_after: list[str] | None = None, + callback_manager: CallbackManager | None = None, + media_store: BaseMediaStore | None = None, + shutdown_timeout: float = 30.0, + ) -> CompiledGraph: + self._configure_graph() + if self._graph is None: # pragma: no cover - _configure_graph always assigns + raise RuntimeError("graph configuration failed") + return self._graph.compile( + checkpointer=checkpointer, + store=store, + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, + callback_manager=callback_manager + if callback_manager is not None + else CallbackManager(), + media_store=media_store, + shutdown_timeout=shutdown_timeout, + ) diff --git a/agentflow/runtime/publisher/events.py b/agentflow/runtime/publisher/events.py index 31aa2ab6..62cfecc9 100644 --- a/agentflow/runtime/publisher/events.py +++ b/agentflow/runtime/publisher/events.py @@ -36,6 +36,7 @@ class Event(StrEnum): LLM_CALL = "llm_call" TOOL_EXECUTION = "tool_execution" STREAMING = "streaming" + REALTIME = "realtime" class EventType(StrEnum): @@ -86,6 +87,7 @@ class ContentType(StrEnum): TOOL_RESULT = "tool_result" IMAGE = "image" AUDIO = "audio" + TRANSCRIPT = "transcript" VIDEO = "video" DOCUMENT = "document" DATA = "data" diff --git a/docs/superpowers/specs/2026-06-14-realtime-audio-agent-design.md b/docs/superpowers/specs/2026-06-14-realtime-audio-agent-design.md new file mode 100644 index 00000000..432d43f5 --- /dev/null +++ b/docs/superpowers/specs/2026-06-14-realtime-audio-agent-design.md @@ -0,0 +1,312 @@ +# Realtime Audio-to-Audio Agent Support — Design + +Date: 2026-06-14 +Status: Approved design (pending spec review) + +## Scope of v1 + +- Provider: **Gemini Live API only**. OpenAI Realtime plugs into the same `RealtimeClient` seam later. 
+- Agent: single prebuilt `AudioAgent` (React-style). **No sub-agents / multi-agent handoff in v1.** +- Surfaces: **Python SDK is the primary, standalone surface.** The HTTP/WebSocket API is **optional + but recommended** — a thin convenience layer over the SDK, not required to use realtime. +- Auth (when the API is used): proxy-through-server. +- Persistence: transcripts only (no audio stored at rest). + +## 1. Problem + +Agentflow supports text and image through a turn-based graph engine (`CompiledGraph.invoke/astream` +yielding final results / `StreamChunk`). There is no realtime audio-to-audio (speech in, speech out). +We add bidirectional realtime voice, backed by Gemini Live, as a first-class part of the **same** +framework — same build/compile/tools/state/checkpointer/publisher — so developers learn one structure +and only the call differs. + +## 2. Why realtime is a separate runtime, not the invoke/stream loop + +The graph's `invoke`/`stream` engine traverses nodes super-step by super-step: **we** decide the next +node after each LLM call. Realtime inverts control — the **provider** owns the turn loop, decides turn +boundaries via VAD, can chain tool calls itself, and streams audio continuously with barge-in. There +is no per-turn control point for the executor to traverse edges. + +Therefore realtime is a **separate runtime** on the compiled graph: `arealtime()`. In it, the live +(audio) agent is the **root controller** that holds one provider WebSocket open for the whole session. +Tool calls and any sub-node/sub-agent invocations happen **inside** that live block; the socket is +never torn down to run them. Other nodes are callable units the live session dispatches — they do +**not** run in the invoke/stream loop. + +This keeps one framework (not two engines) while respecting who actually owns the turn loop. + +## 3. 
Layering (SDK-first) + +``` +[Python SDK — primary, standalone] + AudioAgent (prebuilt) -> LiveAgent (node) -> CompiledGraph.arealtime(queue, config) + developer feeds an audio input queue, consumes RealtimeEvent async iterator + | + | (uses) agentflow/core/realtime/ + | - LiveInputQueue (upstream decoupler) + | - RealtimeClient (provider-neutral protocol) + | - GeminiLiveClient (wraps google-genai client.aio.live) + v + Google Gemini Live API (one WebSocket per session = the spine) + +[HTTP/WS API — optional, recommended] + agentflow-api: /v1/graph/live bridges a client WebSocket <-> LiveInputQueue <-> arealtime() + adds auth, transport, browser access. NOT required: the SDK runs realtime by itself. +``` + +Hard boundary: **core never imports FastAPI.** `LiveAgent` owns only the **provider** socket. The +**client** WebSocket (when the API is used) lives in `agentflow-api` and bridges to `arealtime()` +through the queue. A pure-Python user supplies their own audio source (mic, file, PThread) into the +queue and consumes events directly — no server involved. + +## 4. 
Core components (`agentflow/core/realtime/`) + +### 4.1 `base.py` — provider-neutral contracts + +`RealtimeEvent` — normalized event everything downstream consumes (discriminated union): + +| type | payload | meaning | +|---|---|---| +| `audio_delta` | `data` (`bytes`, PCM16), `sample_rate` | model audio out | +| `input_transcript` | `text`, `finished` | user speech transcript | +| `output_transcript` | `text`, `finished` | model speech transcript | +| `tool_call` | `id`, `name`, `args` | provider requests a tool | +| `tool_result` | `id`, `result` | tool finished (observability) | +| `turn_complete` | — | model finished a turn | +| `interrupted` | — | barge-in; client flushes playback | +| `session_update` | `resumption_handle` | provider resume token | +| `go_away` | `time_left` | provider will close socket soon | +| `agent_changed` | `author` | active agent changed (future multi-agent) | +| `error` | `code`, `message` | provider error | + +`RealtimeClient` — Protocol, one impl per provider: + +```python +class RealtimeClient(Protocol): + async def connect(self, config: RealtimeConfig, resume_handle: str | None = None) -> None: ... + async def send_audio(self, pcm: bytes, sample_rate: int) -> None: ... + async def send_text(self, text: str) -> None: ... + async def send_activity_start(self) -> None: ... # manual VAD / push-to-talk + async def send_activity_end(self) -> None: ... + async def send_tool_response(self, call_id: str, name: str, result: Any) -> None: ... + def receive(self) -> AsyncIterator[RealtimeEvent]: ... + async def close(self) -> None: ... +``` + +`RealtimeConfig` — per-session value object: `model`, `response_modalities` (single: `AUDIO`|`TEXT`), +`voice`, `system_instruction`, `input_audio_transcription`, `output_audio_transcription`, `vad` +(auto + sensitivity, or disabled for push-to-talk), `context_window_compression`, +`session_resumption`, `tools`, `tools_tags`.
+ +### 4.2 `providers/gemini_live.py` — GeminiLiveClient + +Wraps `client.aio.live.connect(model=..., config=types.LiveConnectConfig(...))` as an async context +manager. Mapping: + +- `send_audio` -> `session.send_realtime_input(audio=types.Blob(data=..., mime_type="audio/pcm;rate=16000"))` +- `send_activity_start/end` -> `session.send_realtime_input(activity_start=ActivityStart())` / `activity_end` +- `send_tool_response` -> `session.send_tool_response(function_responses=[types.FunctionResponse(...)])` +- `receive()` maps `LiveServerMessage` -> `RealtimeEvent`: + - `server_content.model_turn.parts[].inline_data` -> `audio_delta` (24kHz) + - `server_content.input_transcription` / `output_transcription` -> transcripts + - `tool_call.function_calls[]` -> `tool_call` + - `server_content.interrupted` -> `interrupted` + - `server_content.generation_complete` / `turn_complete` -> `turn_complete` + - `session_resumption_update` -> `session_update`; `go_away` -> `go_away` + +Audio facts: input PCM16 16kHz mono; output PCM16 24kHz. Client/SDK user resamples mic to 16kHz. + +### 4.3 `queue.py` — LiveInputQueue + +Upstream decoupler. Thin wrapper over `asyncio.Queue` of a `LiveInput` union +(`audio`|`text`|`activity_start`|`activity_end`|`close`). Synchronous non-blocking `put` (`put_nowait`). +Fresh queue per session. Lets the input side keep accepting audio while the model is still generating +— the precondition for barge-in. This is the object an SDK user (or the API bridge) feeds. + +### 4.4 `LiveAgent` — the realtime node and root controller + +`LiveAgent` subclasses the **base agent** (`BaseAgent`), reuses the **config/builder** mixins +(`AgentSkillsMixin`, `AgentMemoryMixin`, `AgentProviderMixin`, the tool-declaration/function-schema +builder, `convert_messages`), and **excludes** `AgentExecutionMixin` (the text turn loop) — it writes +its own duplex loop.
It is a valid graph node (registerable via `add_node`) and the constructor +surface mirrors `Agent`, so tools, `container` (InjectQ), state, skills, memory, callbacks all pass +through identically to `ReactAgent`. + +Behavior when entered (by `arealtime()`): it is the **root controller** for the session. + +1. Opens the provider WebSocket (the spine; held for the whole session). +2. Runs two concurrent tasks over the `LiveInputQueue`: + - **pump task**: drains the queue -> provider (`send_audio`/`send_text`/activity). + - **receive loop** (the generator body): iterates `client.receive()` and per event: + - `audio_delta`/transcripts/`turn_complete`/`interrupted` -> yield to caller; `interrupted` + also fires a publisher/callback event. + - `tool_call` -> `ToolNode.invoke(name, args, config, state)` **internally** (existing parallel + exec, InjectQ deps, MCP); then `send_tool_response`. Socket stays open. Callbacks + publisher + fire from inside `ToolNode` (see §5). + - **route to another agent/node** (when present; future) -> invoke it as a callable unit, await + result, feed it back into the socket as content/tool-response. **Socket not torn down.** + - transcript `finished=True` -> append a `Message` to state via reducers; persist (§7). + - `session_update` -> cache resumption handle, persist to thread metadata. + - `go_away`/drop -> transparent reconnect (§8). +3. Joined with `asyncio.gather(..., return_exceptions=True)`; `close()` mandatory in `finally` to + avoid leaking provider sessions against quota. + +`on_node_start/end` callbacks fire once for the whole `LiveAgent` run. `on_llm_*` has no discrete +call; map to per-turn (`turn_complete`) or skip. + +### 4.5 `AudioAgent` (prebuilt) — single agent, React-style + +`prebuilt/agent/audio.py`. Mirrors `ReactAgent`'s builder signature (model, tools, container, state, +skills, memory, publisher, checkpointer, callbacks, system_prompt, `RealtimeConfig`), wraps +`LiveAgent` as the graph root. 
**No sub-agents/handoff wired in v1** (gated off; a handoff tool is +just a tool, so the door stays open). This is what users instantiate. + +## 5. `CompiledGraph.arealtime` — new runtime + transparency + +New methods alongside `invoke`/`ainvoke`/`stream`/`astream`: + +```python +async def arealtime(self, input_queue: LiveInputQueue, + config: dict) -> AsyncIterator[RealtimeEvent]: ... +def realtime(self, input_queue, config): ... # sync wrapper +``` + +`arealtime` is a **separate runtime**, not the super-step loop. The live agent is the root; ordinary +nodes (preprocess, memory-preload, post-summarize) run as bounded phases or as callable units the live +session dispatches — never as traversed loop nodes during the live phase. + +**Forcing rule:** in a realtime graph the root/entry must be the live agent. `invoke`/`stream` on a +graph containing a `LiveAgent` node -> raise ("use `.arealtime()`"). `arealtime` on a graph with no +live agent -> raise. One live node per realtime run in v1 (multiple = mic/VAD/voice ownership +conflict; sequential transfer only, later). + +**Transparency is inherited free.** `compile()` already binds `publisher`, `callback_manager`, +`container`, `id_generator` into the InjectQ container (`state_graph.py:166-180`). `ToolNode.invoke` +(`tool_node/base.py:280`) pulls `callback_manager = Inject[CallbackManager]` and fires `publish_event` ++ `execute_before/after_invoke` + `execute_on_error` **inside itself**; `publish_event` +(`publish.py:31`) pulls `publisher = Inject[BasePublisher]` from the same container. So when `LiveAgent` +calls `ToolNode.invoke`, callbacks + publisher events fire **identically to text mode**, with zero +extra wiring — because the binding lives in the container, not the traversal. The LLM/turn-level +events that the text `Agent` node publishes are emitted by `LiveAgent` itself through the same +`publish_event` (the new realtime types — see §6). 
A publisher (OTEL/Kafka/Redis) on a thread sees one +continuous event stream whether the turn was text or audio. + +Pure-SDK usage (no API): + +```python +from agentflow.core.realtime import LiveInputQueue +audio = AudioAgent(model="gemini-3.1-flash-live-preview", tools=[...], container=c, publisher=pub, ...) +graph = audio.compile(checkpointer=cp) # same build/compile as ReactAgent +q = LiveInputQueue() +# feed mic frames: q.send_audio(pcm_16k); consume: +async for event in graph.arealtime(q, config={"thread_id": tid, "user_id": uid}): + if event.type == "audio_delta": play(event.data) + ... +``` + +## 6. Publisher / callback contract change + +Realtime emits events the current `Event`/`EventType` enums +(`agentflow/runtime/publisher/events.py`) do not cover: `interrupted`, `input_transcript`, +`output_transcript`, `go_away`, `session_resumed`, `agent_changed`. Extend the enums (add a `REALTIME` +event category + the new types) so all publisher backends inherit realtime telemetry without +per-backend changes. Existing `on_tool_start/end` callbacks fire normally. Small but real contract +change — call it out in the PR. + +## 7. State persistence — transcripts only, no audio at rest + +Audio is **not** persisted (~1.9MB/min, no retrieval value). Both-side transcripts come from Gemini +`input_audio_transcription` (user) + `output_audio_transcription` (model). On a finished transcript +turn, `LiveAgent` appends a `Message` to `AgentState` via the existing reducers (`add_messages`): + +- role `user`/`assistant`, content = `TextBlock(text=transcript)`, metadata `{modality:"audio"}`. +- No change to `AudioBlock` (whose `media: MediaRef` is required and we have no file). + +Persisted via the existing checkpointer (`aput_messages`/`aput_state`/`aput_thread`) at **VAD-turn +granularity** (not per frame). The Gemini resumption handle is stored in thread metadata.
Because turns +persist as normal `Message` objects on the thread, the **same `thread_id` is continuable by the text +`ReactAgent` too** — one durable conversation across modalities. + +## 8. Resumption — two tiers, sized by context not clock + +- **Within session (socket reconnect):** on `session_update`, cache Gemini's resumption handle and + persist to thread metadata. On `go_away`/drop, reconnect with `SessionResumptionConfig(handle=...)` + under the receive loop; the caller-facing generator sees no gap. Logic lives in `LiveAgent`, not the + API. Enabled by default. +- **Cross session (durable thread resume):** on a new session with an existing `thread_id`, load + thread history and reseed it into the new live session via Gemini `send_client_content` initial + history. History is **compressed by the existing context manager** (`SummaryContextManager` / + `BaseContextManager.atrim_context`) — reseed `last-N turns + running summary` sized to the context + window, **not** a time window. Reuse `context_window_compression` (sliding window) for live growth. + Audio is never replayed (none stored). + +## 9. API layer — `/v1/graph/live` (optional, recommended) + +New router `agentflow-api/.../routers/realtime/`, separate from `/v1/graph/ws`. Thin bridge over the +SDK; not required to use realtime. + +1. Client opens WS. Auth via existing `RequirePermission("graph","stream")` + `?token=` fallback. +2. Read init JSON control frame: `{model, thread_id?, voice?, modalities?, vad?, tools_tags?, system_prompt?}`. +3. Build `AudioAgent` + `ToolNode`, load thread history, create a `LiveInputQueue`, call `arealtime`. +4. Two concurrent tasks: + - **upstream**: client WS frames -> queue. Binary frame = PCM16 audio; JSON control = + `{type:"activity_start"|"activity_end"|"text"|"close"}`. + - **downstream**: `async for event in graph.arealtime(queue, cfg)` -> client. 
Audio as a **binary** + frame; metadata/control (transcripts, turn_complete, interrupted, tool_call, errors) as a **JSON + text** frame (ADK bandwidth split, ~75% less than base64-in-JSON). +5. `asyncio.gather(...)`; `finally: graph.aclose()` + `queue.close()`. + +Wire protocol is provider-neutral (client never sees Gemini vs OpenAI). CLI: no new command; served by +existing `agentflow api`. + +## 10. Packaging + +- New optional extra `realtime` in core `pyproject.toml` (depends on `google-genai>=1.56.0`, already + present as the `google-genai` extra). OpenAI Realtime reuses the `openai` extra later. +- All realtime imports guarded; core import never pulls realtime deps (per CLAUDE.md optional-deps rule). + +## 11. Error handling + +- Provider `error` -> normalized `RealtimeEvent(error)`; fatal closes session, transient continues. +- Client/queue disconnect -> cancel tasks, `close()`, persist final state. +- Tool failure -> error result via `send_tool_response` so the model recovers, plus observability event. +- Reconnect failure after N attempts -> emit `error`, close. + +## 12. Testing strategy (TDD, no live LLM) + +- `FakeRealtimeClient` yields scripted `RealtimeEvent` sequences to drive `LiveAgent`/`arealtime`. +- Tool loop: fake emits `tool_call` -> assert `ToolNode.invoke` called -> `send_tool_response` with result. +- Transparency: assert publisher events + callbacks fire on the tool loop identically to text mode. +- Barge-in: `interrupted` mid-audio -> event propagated, pump task alive. +- Within-session resume: `go_away` -> reconnect with stored handle, stream continuity. +- Cross-session resume: existing `thread_id` -> history loaded, compressed via context manager, reseeded. +- Transcript persistence: finished transcripts -> `Message`(`TextBlock`+`{modality:"audio"}`) appended + via reducer + `aput_messages`; assert no audio bytes persisted. +- Forcing rule: `invoke`/`stream` on a live-rooted graph raises; `arealtime` on a non-live graph raises. 
+- API endpoint (mock agent): binary/JSON split, auth rejection, thread persistence. +- Coverage stays >= 70%. + +## 13. Future (explicit, out of v1; architecture must not block) + +- **Case 1 — text sub-agent as agent-as-tool.** A self-contained `CompiledGraph` invoked as a tool by + the live session (framework has no subgraph nesting; one top graph, many nodes). Socket held open; + use Gemini `NON_BLOCKING` scheduling so the model can say "one moment" while it runs. Static routing + = phases; dynamic mid-voice routing must surface as a tool/handoff (the only provider control point). +- **Case 2 — realtime sub-agent as a persona swap on ONE shared socket** (ADK model): transfer swaps + system instruction + tool set, `agent_changed` author tag updates, same audio stream. Per-agent voice + via agent-level `speech_config`. +- **Rejected — two concurrent provider sockets** (Case 3): two audio streams/VADs/voices fighting one + mic; provider one-voice/one-modality per session. Not viable; use sequential transfer instead. +- OpenAI Realtime provider behind the same `RealtimeClient`. +- TypeScript client SDK / `useRealtime` hook / browser audio I/O. +- Ephemeral-token browser-direct (only ever a no-tools mode; agentic requires proxy). + +## 14. Build phases + +- **Phase 1**: `base.py` contracts + `GeminiLiveClient` + normalizer + `LiveInputQueue`. Unit tests vs fake socket. +- **Phase 2**: `LiveAgent` (subclass base, duplex loop, tool loop, transcript persist, resume) + + `RealtimeConfig` + `AudioAgent` prebuilt. Tool/barge-in/resume/transparency/persistence tests. +- **Phase 3**: `CompiledGraph.arealtime`/`realtime` + forcing-rule guards + publisher enum extension. + Pure-SDK end-to-end test with fake provider. +- **Phase 4** (optional surface): `/v1/graph/live` API endpoint, auth reuse, binary/JSON frame split. +- **Phase 5** (future): OpenAI provider; multi-agent (Case 1 + 2); TS client; ephemeral tokens. 
diff --git a/examples/realtime/README.md b/examples/realtime/README.md new file mode 100644 index 00000000..7e85c14a --- /dev/null +++ b/examples/realtime/README.md @@ -0,0 +1,79 @@ +# Realtime audio-to-audio (Gemini Live) + +These examples use `AudioAgent`, the prebuilt realtime agent. Unlike `invoke`/`stream` +(turn-based super-step traversal), a realtime graph is driven by a separate runtime, +`CompiledGraph.arealtime(input_queue, config)`, because the provider owns the turn loop. + +- Input audio: PCM16, mono, 16 kHz. +- Output audio: PCM16, mono, 24 kHz. +- Transcripts are persisted as `Message`s (`metadata={"modality": "audio"}`); raw audio is + never stored. + +## Install + +```bash +pip install "10xscale-agentflow[realtime]" # pulls in google-genai +export GEMINI_API_KEY=... +# Optional: pick a Gemini Live model (defaults to gemini-live-2.5-flash-preview). +# Valid Live model names come from the google-genai SDK, e.g.: +# gemini-live-2.5-flash-preview +# gemini-2.0-flash-live-preview-04-09 +# Check Google's current docs for availability in your region. +export GEMINI_LIVE_MODEL=gemini-live-2.5-flash-preview +``` + +## 1. Headless: WAV file in, WAV file out + +No microphone or speaker needed. Good for a first run / CI. + +```bash +python examples/realtime/audio_agent_file.py path/to/input.wav # 16 kHz mono PCM16 +# writes out.wav and prints transcripts + tool calls +``` + +## 2. Live microphone (full duplex) + +Speak and the agent talks back, with barge-in and tool calling. + +```bash +pip install sounddevice +python examples/realtime/audio_agent_mic.py +# speak; Ctrl+C to stop +``` + +## 3. Through the API server (`/v1/graph/live` WebSocket) + +```bash +cd examples/realtime +agentflow api +# connect a WebSocket client to ws://localhost:8000/v1/graph/live +``` + +Protocol: + +- First frame: a JSON control frame, e.g. `{"model": "...", "thread_id": "abc", "voice": "Puck"}`. + Present fields override the agent's build-time config for that session. 
+- Upstream: binary frame = PCM16 input audio; JSON control frame = + `{"type": "text" | "activity_start" | "activity_end" | "close", ...}`. +- Downstream: binary frame = PCM16 model audio; JSON text frame = every other event + (transcripts, `turn_complete`, `interrupted`, `tool_call`, session/`go_away`, `error`). + +## Key APIs + +```python +from agentflow.core.realtime.base import RealtimeConfig +from agentflow.core.realtime.queue import LiveInputQueue +from agentflow.prebuilt.agent import AudioAgent + +app = AudioAgent( + "gemini-live-2.5-flash-preview", + realtime_config=RealtimeConfig(model="gemini-live-2.5-flash-preview", voice="Puck"), + tools=[my_tool], # advertised to the model automatically +).compile() + +queue = LiveInputQueue() +queue.send_audio(pcm16_bytes) # non-blocking; safe to call from an audio callback +async for event in app.arealtime(queue, {"thread_id": "t1"}): + ... # AudioDeltaEvent / transcripts / ToolCallEvent / ... +queue.close() # ends the session once the provider goes idle +``` diff --git a/examples/realtime/__init__.py b/examples/realtime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/realtime/agentflow.json b/examples/realtime/agentflow.json new file mode 100644 index 00000000..0f5595e7 --- /dev/null +++ b/examples/realtime/agentflow.json @@ -0,0 +1,8 @@ +{ + "dependencies": ["."], + "graphs": { + "agent": "./graph.py:app", + "checkpointer": "./graph.py:checkpointer" + }, + "env": ".env" +} diff --git a/examples/realtime/audio_agent_file.py b/examples/realtime/audio_agent_file.py new file mode 100644 index 00000000..bbc8f4ba --- /dev/null +++ b/examples/realtime/audio_agent_file.py @@ -0,0 +1,106 @@ +"""Realtime audio-to-audio from a WAV file, using AudioAgent + Gemini Live. + +No microphone or speaker required. This streams a 16 kHz mono PCM16 WAV into the live +session, writes the model's 24 kHz audio reply to ``out.wav``, and prints the input and +output transcripts plus any tool calls. 
It is the headless counterpart to +``audio_agent_mic.py`` and is the easiest way to sanity-check your setup. + +Setup + pip install "10xscale-agentflow[realtime]" + export GEMINI_API_KEY=... + # optionally override the model (see README for valid Gemini Live models): + export GEMINI_LIVE_MODEL=gemini-live-2.5-flash-preview + +Run + python examples/realtime/audio_agent_file.py path/to/input.wav + +``input.wav`` must be mono, 16-bit PCM, 16 kHz (the format Gemini Live expects for input). +""" + +import asyncio +import os +import sys +import wave + +from dotenv import load_dotenv + +from agentflow.core.realtime.base import OUTPUT_SAMPLE_RATE, RealtimeConfig +from agentflow.core.realtime.queue import LiveInputQueue +from agentflow.prebuilt.agent import AudioAgent + + +load_dotenv() + +MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-live-2.5-flash-preview") + + +def get_weather(location: str) -> str: + """Return the current weather for a city. Called by the model during the conversation.""" + return f"It is 22 degrees Celsius and sunny in {location}." + + +def build_app(): + """Compile a single realtime audio agent with one tool and a voice.""" + config = RealtimeConfig( + model=MODEL, + voice="Puck", + system_instruction="You are a concise voice assistant. 
Keep answers to one or two sentences.", + ) + return AudioAgent(MODEL, realtime_config=config, tools=[get_weather]).compile() + + +def read_pcm16(path: str) -> tuple[int, bytes]: + """Read a mono PCM16 WAV file, returning (sample_rate, raw_pcm_bytes).""" + with wave.open(path, "rb") as wf: + if wf.getsampwidth() != 2 or wf.getnchannels() != 1: + raise ValueError("input must be mono, 16-bit PCM (sampwidth=2, channels=1)") + return wf.getframerate(), wf.readframes(wf.getnframes()) + + +async def feed_audio(queue: LiveInputQueue, pcm: bytes, sample_rate: int) -> None: + """Stream the file into the session in ~100 ms chunks (the audio hot path).""" + chunk = (sample_rate // 10) * 2 # 100 ms * 2 bytes/sample + for offset in range(0, len(pcm), chunk): + queue.send_audio(pcm[offset : offset + chunk], sample_rate=sample_rate) + await asyncio.sleep(0.0) # yield so the pump task can flush to the socket + # Automatic VAD detects end-of-speech; we leave the queue open to receive the reply + # and close it from the main loop once the model finishes its turn. + + +async def main() -> None: + in_path = sys.argv[1] if len(sys.argv) > 1 else "input.wav" + if not os.path.exists(in_path): + sys.exit(f"Provide a 16 kHz mono PCM16 WAV path. 
Not found: {in_path}") + + sample_rate, pcm = read_pcm16(in_path) + app = build_app() + queue = LiveInputQueue() + + out = wave.open("out.wav", "wb") + out.setnchannels(1) + out.setsampwidth(2) + out.setframerate(OUTPUT_SAMPLE_RATE) + + feeder = asyncio.create_task(feed_audio(queue, pcm, sample_rate)) + try: + async for event in app.arealtime(queue, {"thread_id": "audio-file-demo"}): + if event.type == "audio_delta": + out.writeframes(event.data) + elif event.type == "input_transcript" and event.finished: + print(f"you: {event.text}") + elif event.type == "output_transcript" and event.finished: + print(f"agent: {event.text}") + elif event.type == "tool_call": + print(f"[tool] {event.name}({event.args})") + elif event.type == "turn_complete": + queue.close() # one turn for this demo: end the session + finally: + feeder.cancel() + out.close() + await app.aclose() + + print("Wrote model audio reply to out.wav") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/realtime/audio_agent_mic.py b/examples/realtime/audio_agent_mic.py new file mode 100644 index 00000000..1bb96c16 --- /dev/null +++ b/examples/realtime/audio_agent_mic.py @@ -0,0 +1,104 @@ +"""Live microphone audio-to-audio with AudioAgent + Gemini Live. + +Speak into your microphone and the agent talks back in real time. It supports barge-in +(start talking while it is speaking and it stops to listen) and tool calls. This is the +full duplex demo; for a headless/no-hardware version see ``audio_agent_file.py``. + +Setup + pip install "10xscale-agentflow[realtime]" sounddevice + export GEMINI_API_KEY=... + export GEMINI_LIVE_MODEL=gemini-live-2.5-flash-preview # optional, see README + +Run + python examples/realtime/audio_agent_mic.py + # speak; press Ctrl+C to stop. 
+""" + +import asyncio +import os +import sys + +from dotenv import load_dotenv + +from agentflow.core.realtime.base import INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE, RealtimeConfig +from agentflow.core.realtime.queue import LiveInputQueue +from agentflow.prebuilt.agent import AudioAgent + + +load_dotenv() + +MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-live-2.5-flash-preview") +MIC_BLOCK = INPUT_SAMPLE_RATE // 10 # 100 ms frames + + +def get_weather(location: str) -> str: + """Return the current weather for a city. Called by the model during the conversation.""" + return f"It is 22 degrees Celsius and sunny in {location}." + + +def build_app(): + config = RealtimeConfig( + model=MODEL, + voice="Puck", + system_instruction="You are a friendly, concise voice assistant.", + ) + return AudioAgent(MODEL, realtime_config=config, tools=[get_weather]).compile() + + +async def main() -> None: + try: + import sounddevice as sd + except ImportError: + sys.exit("This example needs sounddevice: pip install sounddevice") + + app = build_app() + queue = LiveInputQueue() + loop = asyncio.get_running_loop() + + def on_mic(indata, _frames, _time, _status) -> None: + # PortAudio calls this on its own thread; marshal onto the event loop so the + # asyncio-backed queue is touched only from the loop thread. + loop.call_soon_threadsafe(queue.send_audio, bytes(indata)) + + speaker = sd.RawOutputStream(samplerate=OUTPUT_SAMPLE_RATE, channels=1, dtype="int16") + mic = sd.RawInputStream( + samplerate=INPUT_SAMPLE_RATE, + channels=1, + dtype="int16", + blocksize=MIC_BLOCK, + callback=on_mic, + ) + speaker.start() + mic.start() + print("Listening. Speak into your mic; press Ctrl+C to stop.") + + try: + async for event in app.arealtime(queue, {"thread_id": "audio-mic-demo"}): + if event.type == "audio_delta": + speaker.write(event.data) + elif event.type == "interrupted": + # Barge-in: discard audio already queued for playback. 
+ speaker.stop() + speaker.start() + elif event.type == "input_transcript" and event.finished: + print(f"you: {event.text}") + elif event.type == "output_transcript" and event.finished: + print(f"agent: {event.text}") + elif event.type == "tool_call": + print(f"[tool] {event.name}({event.args})") + except (KeyboardInterrupt, asyncio.CancelledError): + print("\nStopping...") + finally: + queue.close() + mic.stop() + mic.close() + speaker.stop() + speaker.close() + await app.aclose() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/examples/realtime/graph.py b/examples/realtime/graph.py new file mode 100644 index 00000000..36c7c2b1 --- /dev/null +++ b/examples/realtime/graph.py @@ -0,0 +1,43 @@ +"""Realtime AudioAgent exposed through the Agentflow API server. + +``agentflow.json`` points the server at ``app`` below. Once running, the server serves a +WebSocket at ``/v1/graph/live`` that bridges browser/client audio to this agent (binary +PCM16 frames upstream, model audio back as binary, transcripts/tool-calls/events as JSON). + +Run + cd examples/realtime + export GEMINI_API_KEY=... + agentflow api + # then connect a WebSocket client to ws://localhost:8000/v1/graph/live +""" + +import os + +from dotenv import load_dotenv + +from agentflow.core.realtime.base import RealtimeConfig +from agentflow.prebuilt.agent import AudioAgent +from agentflow.storage.checkpointer import InMemoryCheckpointer + + +load_dotenv() + +MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-live-2.5-flash-preview") + + +def get_weather(location: str) -> str: + """Return the current weather for a city. Called by the model during the conversation.""" + return f"It is 22 degrees Celsius and sunny in {location}." 
+ + +checkpointer = InMemoryCheckpointer() + +app = AudioAgent( + MODEL, + realtime_config=RealtimeConfig( + model=MODEL, + voice="Puck", + system_instruction="You are a concise, friendly voice assistant.", + ), + tools=[get_weather], +).compile(checkpointer=checkpointer) diff --git a/tests/publisher/test_events.py b/tests/publisher/test_events.py index 4ed18ad2..3e076274 100644 --- a/tests/publisher/test_events.py +++ b/tests/publisher/test_events.py @@ -32,10 +32,12 @@ def test_event_enum_values(self): assert Event.TOOL_EXECUTION.value == "tool_execution" assert Event.STREAMING.value == "streaming" + assert Event.REALTIME.value == "realtime" + # Test all expected values are present expected_values = { "graph_execution", "node_execution", "llm_call", - "tool_execution", "streaming" + "tool_execution", "streaming", "realtime" } actual_values = {event.value for event in Event} assert actual_values == expected_values @@ -73,12 +75,13 @@ def test_content_type_enum_values(self): assert ContentType.STATE.value == "state" assert ContentType.UPDATE.value == "update" assert ContentType.ERROR.value == "error" - + assert ContentType.TRANSCRIPT.value == "transcript" + # Test all expected values are present expected_values = { "text", "message", "reasoning", "tool_call", "tool_result", - "image", "audio", "video", "document", "data", - "state", "update", "error" + "image", "audio", "video", "document", "data", + "state", "update", "error", "transcript" } actual_values = {content_type.value for content_type in ContentType} assert actual_values == expected_values diff --git a/tests/realtime/__init__.py b/tests/realtime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/realtime/test_arealtime.py b/tests/realtime/test_arealtime.py new file mode 100644 index 00000000..98452865 --- /dev/null +++ b/tests/realtime/test_arealtime.py @@ -0,0 +1,114 @@ +"""Phase 3: CompiledGraph.arealtime/realtime runtime + forcing-rule guards. 
+ +Pure-SDK end-to-end: AudioAgent -> compile -> arealtime drives a fake provider socket, +no server and no live LLM involved. +""" + +import pytest + +from agentflow.core.realtime.base import ( + AudioDeltaEvent, + InputTranscriptEvent, + OutputTranscriptEvent, + ToolCallEvent, + TurnCompleteEvent, +) +from agentflow.core.realtime.queue import LiveInputQueue +from agentflow.prebuilt.agent import AudioAgent, ReactAgent +from agentflow.storage.checkpointer import InMemoryCheckpointer +from tests.realtime.test_live_agent import FakeRealtimeClient, _factory + +MODEL = "gemini-2.5-flash-live" + + +class TestArealtimeRuntime: + @pytest.mark.asyncio + async def test_arealtime_drives_live_agent_end_to_end(self): + client = FakeRealtimeClient([AudioDeltaEvent(data=b"\x01"), TurnCompleteEvent()]) + app = AudioAgent(MODEL, realtime_client_factory=_factory(client)).compile() + + q = LiveInputQueue() + q.close() + events = [e async for e in app.arealtime(q, {"thread_id": "t1"})] + + assert [e.type for e in events] == ["audio_delta", "turn_complete"] + + @pytest.mark.asyncio + async def test_arealtime_persists_transcripts_through_compiled_checkpointer(self): + client = FakeRealtimeClient( + [ + InputTranscriptEvent(text="hello", finished=True), + OutputTranscriptEvent(text="hi", finished=True), + TurnCompleteEvent(), + ] + ) + cp = InMemoryCheckpointer() + app = AudioAgent(MODEL, realtime_client_factory=_factory(client)).compile(checkpointer=cp) + + config = {"thread_id": "t-cp", "user_id": "u1"} + async for _ in app.arealtime(q := LiveInputQueue(), config): + q.close() # close after first event so the loop can finish + + persisted = await cp.alist_messages(config) + assert {m.content[0].text for m in persisted} == {"hello", "hi"} + + @pytest.mark.asyncio + async def test_arealtime_tool_loop_uses_compiled_toolnode(self): + def get_time() -> str: + return "12:00" + + client = FakeRealtimeClient([ToolCallEvent(id="c1", name="get_time", args={})]) + app = AudioAgent( + MODEL, 
tools=[get_time], realtime_client_factory=_factory(client) + ).compile() + + q = LiveInputQueue() + q.close() + events = [e async for e in app.arealtime(q, {"thread_id": "t1"})] + + assert any(e.type == "tool_result" and e.result == {"result": "12:00"} for e in events) + assert client.tool_responses[0][1] == "get_time" + + +class TestForcingRule: + @pytest.mark.asyncio + async def test_invoke_on_live_graph_raises(self): + app = AudioAgent(MODEL, realtime_client_factory=_factory(FakeRealtimeClient())).compile() + with pytest.raises(RuntimeError, match="arealtime"): + await app.ainvoke({"messages": []}, {"thread_id": "t1"}) + + @pytest.mark.asyncio + async def test_astream_on_live_graph_raises(self): + app = AudioAgent(MODEL, realtime_client_factory=_factory(FakeRealtimeClient())).compile() + with pytest.raises(RuntimeError, match="arealtime"): + async for _ in app.astream({"messages": []}, {"thread_id": "t1"}): + pass + + @pytest.mark.asyncio + async def test_arealtime_on_non_live_graph_raises(self): + app = ReactAgent(model="gemini-2.5-flash").compile() + q = LiveInputQueue() + q.close() + with pytest.raises(RuntimeError, match="LiveAgent"): + async for _ in app.arealtime(q, {"thread_id": "t1"}): + pass + + +class TestRealtimeSyncWrapper: + def test_realtime_sync_iteration(self): + client = FakeRealtimeClient([AudioDeltaEvent(data=b"\x02"), TurnCompleteEvent()]) + app = AudioAgent(MODEL, realtime_client_factory=_factory(client)).compile() + + q = LiveInputQueue() + q.close() + events = list(app.realtime(q, {"thread_id": "t1"})) + + assert [e.type for e in events] == ["audio_delta", "turn_complete"] + + @pytest.mark.asyncio + async def test_realtime_sync_rejected_inside_running_loop(self): + app = AudioAgent(MODEL, realtime_client_factory=_factory(FakeRealtimeClient())).compile() + q = LiveInputQueue() + q.close() + with pytest.raises(RuntimeError, match="running event loop"): + list(app.realtime(q, {"thread_id": "t1"})) diff --git 
a/tests/realtime/test_audio_agent.py b/tests/realtime/test_audio_agent.py new file mode 100644 index 00000000..61b84ffc --- /dev/null +++ b/tests/realtime/test_audio_agent.py @@ -0,0 +1,54 @@ +"""Unit tests for the AudioAgent prebuilt (wraps LiveAgent as a compiled graph root).""" + +import pytest + +from agentflow.core.graph.compiled_graph import CompiledGraph +from agentflow.core.realtime.live_agent import LiveAgent +from agentflow.prebuilt.agent.audio import AudioAgent + +MODEL = "gemini-2.5-flash-live" + + +class TestAudioAgentBuild: + def test_compile_returns_compiled_graph_with_live_root(self): + agent = AudioAgent(MODEL) + app = agent.compile() + + assert isinstance(app, CompiledGraph) + live_node = app._state_graph.nodes[agent._live_node_name] + assert isinstance(live_node.func, LiveAgent) + assert app._state_graph.entry_point == agent._live_node_name + + def test_tools_are_wired_into_the_live_agent(self): + def get_weather(city: str) -> str: + return f"sunny in {city}" + + agent = AudioAgent(MODEL, tools=[get_weather]) + agent.compile() + + assert agent._agent._resolve_tool_node() is not None + + def test_realtime_config_passthrough(self): + from agentflow.core.realtime.base import RealtimeConfig + + cfg = RealtimeConfig(model=MODEL, voice="Puck", response_modalities=["AUDIO"]) + agent = AudioAgent(MODEL, realtime_config=cfg) + + assert agent._agent.realtime_config.voice == "Puck" + + @pytest.mark.asyncio + async def test_compiled_live_agent_runs_via_arun(self): + # End-to-end-ish: the LiveAgent inside the compiled graph drives a fake socket. 
+ from agentflow.core.realtime.base import AudioDeltaEvent, TurnCompleteEvent + from agentflow.core.realtime.queue import LiveInputQueue + from tests.realtime.test_live_agent import FakeRealtimeClient, _factory + + client = FakeRealtimeClient([AudioDeltaEvent(data=b"\x01"), TurnCompleteEvent()]) + agent = AudioAgent(MODEL, realtime_client_factory=_factory(client)) + agent.compile() + + q = LiveInputQueue() + q.close() + events = [e async for e in agent._agent.arun(q, {"thread_id": "t1"})] + + assert [e.type for e in events] == ["audio_delta", "turn_complete"] diff --git a/tests/realtime/test_base.py b/tests/realtime/test_base.py new file mode 100644 index 00000000..075c18b7 --- /dev/null +++ b/tests/realtime/test_base.py @@ -0,0 +1,108 @@ +"""Unit tests for provider-neutral realtime contracts (agentflow.core.realtime.base).""" + +import pytest +from pydantic import TypeAdapter, ValidationError + +from agentflow.core.realtime.base import ( + AgentChangedEvent, + AudioDeltaEvent, + ErrorEvent, + GoAwayEvent, + InputTranscriptEvent, + InterruptedEvent, + OutputTranscriptEvent, + RealtimeConfig, + RealtimeEvent, + SessionUpdateEvent, + ToolCallEvent, + ToolResultEvent, + TurnCompleteEvent, +) + + +class TestRealtimeEventDiscrimination: + def test_audio_delta_carries_pcm_and_sample_rate(self): + event = AudioDeltaEvent(data=b"\x00\x01", sample_rate=24000) + assert event.type == "audio_delta" + assert event.data == b"\x00\x01" + assert event.sample_rate == 24000 + + def test_input_and_output_transcripts_track_finished_flag(self): + user = InputTranscriptEvent(text="hello", finished=False) + model = OutputTranscriptEvent(text="hi there", finished=True) + assert user.type == "input_transcript" + assert user.finished is False + assert model.type == "output_transcript" + assert model.finished is True + + def test_tool_call_and_result_pair_by_id(self): + call = ToolCallEvent(id="c1", name="get_weather", args={"city": "Paris"}) + result = ToolResultEvent(id="c1", 
result={"temp": 20}) + assert call.type == "tool_call" + assert call.name == "get_weather" + assert call.args == {"city": "Paris"} + assert result.type == "tool_result" + assert result.id == call.id + + def test_lifecycle_events_have_no_required_payload(self): + assert TurnCompleteEvent().type == "turn_complete" + assert InterruptedEvent().type == "interrupted" + + def test_session_and_goaway_carry_resume_metadata(self): + update = SessionUpdateEvent(resumption_handle="abc123") + goaway = GoAwayEvent(time_left="5s") + assert update.type == "session_update" + assert update.resumption_handle == "abc123" + assert goaway.type == "go_away" + assert goaway.time_left == "5s" + + def test_agent_changed_and_error(self): + changed = AgentChangedEvent(author="planner") + err = ErrorEvent(code="quota", message="rate limited") + assert changed.type == "agent_changed" + assert changed.author == "planner" + assert err.type == "error" + assert err.code == "quota" + assert err.message == "rate limited" + + def test_union_deserializes_by_type_discriminator(self): + adapter = TypeAdapter(RealtimeEvent) + parsed = adapter.validate_python({"type": "interrupted"}) + assert isinstance(parsed, InterruptedEvent) + parsed_call = adapter.validate_python( + {"type": "tool_call", "id": "x", "name": "f", "args": {}} + ) + assert isinstance(parsed_call, ToolCallEvent) + + def test_union_rejects_unknown_type(self): + adapter = TypeAdapter(RealtimeEvent) + with pytest.raises(ValidationError): + adapter.validate_python({"type": "not_a_real_event"}) + + +class TestRealtimeConfig: + def test_minimal_config_requires_only_model(self): + config = RealtimeConfig(model="gemini-2.5-flash-live") + assert config.model == "gemini-2.5-flash-live" + # Audio-out by default for an audio agent. + assert config.response_modalities == ["AUDIO"] + + def test_single_response_modality_enforced(self): + # Gemini Live allows exactly one response modality per session. 
+ with pytest.raises(ValidationError): + RealtimeConfig(model="m", response_modalities=["AUDIO", "TEXT"]) + + def test_full_config_round_trip(self): + config = RealtimeConfig( + model="gemini-2.5-flash-live", + response_modalities=["TEXT"], + voice="Puck", + system_instruction="be terse", + input_audio_transcription=True, + output_audio_transcription=True, + session_resumption=True, + tools_tags=["weather"], + ) + assert config.voice == "Puck" + assert config.input_audio_transcription is True + assert config.tools_tags == ["weather"] diff --git a/tests/realtime/test_gemini_live.py b/tests/realtime/test_gemini_live.py new file mode 100644 index 00000000..cf9706c0 --- /dev/null +++ b/tests/realtime/test_gemini_live.py @@ -0,0 +1,390 @@ +"""Unit tests for GeminiLiveClient and its LiveServerMessage normalizer. + +No live LLM and no real socket: a FakeLiveSession records sent frames and yields +scripted server messages. Server-message normalization is duck-typed (reads attributes) +so we drive it with lightweight stand-ins shaped like google.genai's LiveServerMessage. 
+""" + +from types import SimpleNamespace + +import pytest + +from agentflow.core.realtime.base import ( + AudioDeltaEvent, + ErrorEvent, + GoAwayEvent, + InputTranscriptEvent, + InterruptedEvent, + OutputTranscriptEvent, + RealtimeConfig, + SessionUpdateEvent, + ToolCallEvent, + TurnCompleteEvent, +) +from agentflow.core.realtime.providers.gemini_live import GeminiLiveClient, normalize_message + + +# --------------------------------------------------------------------------- # +# Message-shape helpers (mirror google.genai LiveServerMessage attribute names) +# --------------------------------------------------------------------------- # +def _msg(**kw): + base = dict( + server_content=None, + tool_call=None, + go_away=None, + session_resumption_update=None, + ) + base.update(kw) + return SimpleNamespace(**base) + + +def _server_content(**kw): + base = dict( + model_turn=None, + input_transcription=None, + output_transcription=None, + interrupted=None, + generation_complete=None, + turn_complete=None, + ) + base.update(kw) + return SimpleNamespace(**base) + + +def _audio_part(data: bytes, mime: str = "audio/pcm;rate=24000"): + inline = SimpleNamespace(data=data, mime_type=mime) + return SimpleNamespace(inline_data=inline, text=None) + + +class TestNormalizeMessage: + def test_audio_part_becomes_audio_delta(self): + content = _server_content( + model_turn=SimpleNamespace(parts=[_audio_part(b"\x10\x20")]) + ) + events = normalize_message(_msg(server_content=content)) + assert len(events) == 1 + assert isinstance(events[0], AudioDeltaEvent) + assert events[0].data == b"\x10\x20" + assert events[0].sample_rate == 24000 + + def test_input_and_output_transcription(self): + content = _server_content( + input_transcription=SimpleNamespace(text="hi", finished=False), + output_transcription=SimpleNamespace(text="hello", finished=True), + ) + events = normalize_message(_msg(server_content=content)) + assert any( + isinstance(e, InputTranscriptEvent) and e.text == "hi" and 
e.finished is False + for e in events + ) + assert any( + isinstance(e, OutputTranscriptEvent) and e.text == "hello" and e.finished is True + for e in events + ) + + def test_interrupted_and_generation_complete(self): + content = _server_content(interrupted=True, generation_complete=True) + events = normalize_message(_msg(server_content=content)) + assert any(isinstance(e, InterruptedEvent) for e in events) + assert any(isinstance(e, TurnCompleteEvent) for e in events) + + def test_tool_call_function_calls(self): + fc = SimpleNamespace(id="call-1", name="get_weather", args={"city": "Paris"}) + tool_call = SimpleNamespace(function_calls=[fc]) + events = normalize_message(_msg(tool_call=tool_call)) + assert len(events) == 1 + assert isinstance(events[0], ToolCallEvent) + assert events[0].id == "call-1" + assert events[0].name == "get_weather" + assert events[0].args == {"city": "Paris"} + + def test_tool_call_synthesizes_id_when_missing(self): + fc = SimpleNamespace(id=None, name="f", args=None) + events = normalize_message(_msg(tool_call=SimpleNamespace(function_calls=[fc]))) + assert isinstance(events[0], ToolCallEvent) + assert events[0].id # non-empty fallback id + assert events[0].args == {} + + def test_session_resumption_update(self): + upd = SimpleNamespace(new_handle="h-123", resumable=True) + events = normalize_message(_msg(session_resumption_update=upd)) + assert len(events) == 1 + assert isinstance(events[0], SessionUpdateEvent) + assert events[0].resumption_handle == "h-123" + + def test_go_away(self): + events = normalize_message(_msg(go_away=SimpleNamespace(time_left="5s"))) + assert isinstance(events[0], GoAwayEvent) + assert events[0].time_left == "5s" + + def test_empty_message_yields_nothing(self): + assert normalize_message(_msg()) == [] + + +# --------------------------------------------------------------------------- # +# Fake session / connector for client lifecycle and send-mapping tests +# 
--------------------------------------------------------------------------- # +class FakeLiveSession: + def __init__(self, scripted=None): + self.scripted = scripted or [] + self.sent_realtime = [] + self.tool_responses = [] + self.closed = False + + async def send_realtime_input(self, **kwargs): + self.sent_realtime.append(kwargs) + + async def send_tool_response(self, **kwargs): + self.tool_responses.append(kwargs) + + async def send_client_content(self, **kwargs): + self.client_content = kwargs + + async def receive(self): + for m in self.scripted: + yield m + + +class FakeConnector: + """Stands in for client.aio.live.connect(...) -> async context manager.""" + + def __init__(self, session): + self.session = session + self.enter_calls = [] + self.exited = False + + def __call__(self, *, model, config): + self.enter_calls.append({"model": model, "config": config}) + return self + + async def __aenter__(self): + return self.session + + async def __aexit__(self, *exc): + self.exited = True + return False + + +@pytest.fixture +def config(): + return RealtimeConfig(model="gemini-2.5-flash-live", voice="Puck") + + +class TestGeminiLiveClientLifecycle: + @pytest.mark.asyncio + async def test_connect_opens_session_via_connector(self, config): + session = FakeLiveSession() + connector = FakeConnector(session) + client = GeminiLiveClient(connector=connector) + + await client.connect(config) + + assert connector.enter_calls[0]["model"] == "gemini-2.5-flash-live" + assert client.connected is True + + @pytest.mark.asyncio + async def test_connect_with_resume_handle_sets_session_resumption(self, config): + connector = FakeConnector(FakeLiveSession()) + client = GeminiLiveClient(connector=connector) + + await client.connect(config, resume_handle="handle-xyz") + + live_config = connector.enter_calls[0]["config"] + assert live_config.session_resumption.handle == "handle-xyz" + + @pytest.mark.asyncio + async def test_reseed_history_maps_messages_to_send_client_content(self, 
config): + from agentflow.core.state import Message + + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.reseed_history( + [ + Message.text_message("hi", role="user"), + Message.text_message("hello", role="assistant"), + ] + ) + + turns = session.client_content["turns"] + assert [t.role for t in turns] == ["user", "model"] + assert session.client_content["turn_complete"] is True + + @pytest.mark.asyncio + async def test_close_exits_context_manager_and_is_idempotent(self, config): + session = FakeLiveSession() + connector = FakeConnector(session) + client = GeminiLiveClient(connector=connector) + await client.connect(config) + + await client.close() + await client.close() # must not raise + + assert connector.exited is True + assert client.connected is False + + +class TestBuildConnectConfig: + @pytest.mark.asyncio + async def test_voice_and_transcription_mapped_into_live_config(self, config): + connector = FakeConnector(FakeLiveSession()) + client = GeminiLiveClient(connector=connector) + await client.connect(config) + + live_config = connector.enter_calls[0]["config"] + assert [m.value for m in live_config.response_modalities] == ["AUDIO"] + assert live_config.speech_config is not None + assert live_config.input_audio_transcription is not None + assert live_config.output_audio_transcription is not None + assert live_config.session_resumption is not None + + @pytest.mark.asyncio + async def test_disabled_vad_sets_manual_activity_detection(self): + from agentflow.core.realtime.base import RealtimeConfig, VADConfig + + cfg = RealtimeConfig(model="m", vad=VADConfig(enabled=False)) + connector = FakeConnector(FakeLiveSession()) + client = GeminiLiveClient(connector=connector) + await client.connect(cfg) + + live_config = connector.enter_calls[0]["config"] + assert live_config.realtime_input_config.automatic_activity_detection.disabled is True + + @pytest.mark.asyncio + async def 
test_context_window_compression_enabled(self): + from agentflow.core.realtime.base import RealtimeConfig + + cfg = RealtimeConfig(model="m", context_window_compression=True) + connector = FakeConnector(FakeLiveSession()) + client = GeminiLiveClient(connector=connector) + await client.connect(cfg) + + live_config = connector.enter_calls[0]["config"] + assert live_config.context_window_compression is not None + + +class TestBuildConnectConfigTools: + @pytest.mark.asyncio + async def test_openai_tool_dicts_become_function_declarations(self): + cfg = RealtimeConfig( + model="m", + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + } + ], + ) + connector = FakeConnector(FakeLiveSession()) + client = GeminiLiveClient(connector=connector) + await client.connect(cfg) + + tools = connector.enter_calls[0]["config"].tools + assert len(tools) == 1 # one Tool wrapping the declarations + decls = tools[0].function_declarations + assert [d.name for d in decls] == ["get_weather"] + + @pytest.mark.asyncio + async def test_non_dict_tools_pass_through_untouched(self): + def raw_callable(): + """A raw callable tool.""" + + cfg = RealtimeConfig(model="m", tools=[raw_callable]) + connector = FakeConnector(FakeLiveSession()) + client = GeminiLiveClient(connector=connector) + await client.connect(cfg) + + tools = connector.enter_calls[0]["config"].tools + assert tools == [raw_callable] + + +class TestGeminiLiveClientSend: + @pytest.mark.asyncio + async def test_send_audio_maps_to_blob_with_pcm_mime(self, config): + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.send_audio(b"\xaa\xbb", sample_rate=16000) + + assert len(session.sent_realtime) == 1 + blob = session.sent_realtime[0]["audio"] + assert blob.data == b"\xaa\xbb" + assert blob.mime_type == 
"audio/pcm;rate=16000" + + @pytest.mark.asyncio + async def test_activity_markers_map_to_realtime_input(self, config): + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.send_activity_start() + await client.send_activity_end() + + assert "activity_start" in session.sent_realtime[0] + assert "activity_end" in session.sent_realtime[1] + + @pytest.mark.asyncio + async def test_send_tool_response_maps_to_function_response(self, config): + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.send_tool_response("call-1", "get_weather", {"temp": 20}) + + responses = session.tool_responses[0]["function_responses"] + assert responses[0].id == "call-1" + assert responses[0].name == "get_weather" + assert responses[0].response == {"temp": 20} + + @pytest.mark.asyncio + async def test_send_text_maps_to_realtime_input(self, config): + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.send_text("hello there") + + assert session.sent_realtime[0]["text"] == "hello there" + + @pytest.mark.asyncio + async def test_send_before_connect_raises(self, config): + client = GeminiLiveClient(connector=FakeConnector(FakeLiveSession())) + with pytest.raises(RuntimeError): + await client.send_audio(b"\x00", sample_rate=16000) + + +class TestGeminiLiveClientReceive: + @pytest.mark.asyncio + async def test_receive_normalizes_scripted_messages_in_order(self, config): + scripted = [ + _msg( + server_content=_server_content( + model_turn=SimpleNamespace(parts=[_audio_part(b"\x01")]) + ) + ), + _msg(server_content=_server_content(turn_complete=True)), + ] + session = FakeLiveSession(scripted=scripted) + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + events = [e async for e in client.receive()] + 
+ assert isinstance(events[0], AudioDeltaEvent) + assert isinstance(events[1], TurnCompleteEvent) + + @pytest.mark.asyncio + async def test_receive_before_connect_raises(self, config): + client = GeminiLiveClient(connector=FakeConnector(FakeLiveSession())) + with pytest.raises(RuntimeError): + async for _ in client.receive(): + pass diff --git a/tests/realtime/test_live_agent.py b/tests/realtime/test_live_agent.py new file mode 100644 index 00000000..c2f18cdb --- /dev/null +++ b/tests/realtime/test_live_agent.py @@ -0,0 +1,456 @@ +"""Unit tests for LiveAgent's duplex loop, tool loop, persistence, and resumption. + +No live LLM and no real socket: a FakeRealtimeClient yields scripted RealtimeEvents and +records everything sent upstream. A client *factory* lets reconnect/resume tests hand out +fresh sockets per connection. +""" + +import asyncio + +import pytest +from injectq import InjectQ + +from agentflow.core.realtime.base import ( + AudioDeltaEvent, + GoAwayEvent, + InputTranscriptEvent, + InterruptedEvent, + OutputTranscriptEvent, + RealtimeConfig, + SessionUpdateEvent, + ToolCallEvent, + ToolResultEvent, + TurnCompleteEvent, +) +from agentflow.core.realtime.live_agent import LiveAgent +from agentflow.core.realtime.queue import LiveInputQueue +from agentflow.core.graph.tool_node import ToolNode +from agentflow.core.state import AgentState, Message +from agentflow.runtime.publisher.base_publisher import BasePublisher +from agentflow.utils import CallbackManager +from agentflow.utils.background_task_manager import BackgroundTaskManager + +MODEL = "gemini-2.5-flash-live" + + +class FakeRealtimeClient: + def __init__(self, events=None): + self.events = events or [] + self.connected_with: list[str | None] = [] + self.connected_config = None + self.sent_audio: list[tuple[bytes, int]] = [] + self.sent_text: list[str] = [] + self.activity: list[str] = [] + self.tool_responses: list[tuple[str, str, object]] = [] + self.reseeded = None + self.closed = False + + async def 
connect(self, config, resume_handle=None): + self.connected_with.append(resume_handle) + self.connected_config = config + + async def send_audio(self, pcm, sample_rate): + self.sent_audio.append((pcm, sample_rate)) + + async def send_text(self, text): + self.sent_text.append(text) + + async def send_activity_start(self): + self.activity.append("start") + + async def send_activity_end(self): + self.activity.append("end") + + async def send_tool_response(self, call_id, name, result): + self.tool_responses.append((call_id, name, result)) + + async def reseed_history(self, messages): + self.reseeded = list(messages) + + async def receive(self): + for event in self.events: + yield event + + async def close(self): + self.closed = True + + +def _factory(*clients): + """Return a factory that hands out the given clients in order.""" + seq = list(clients) + + def make(): + return seq.pop(0) + + return make + + +def _closed_queue(): + q = LiveInputQueue() + q.close() # pump exits immediately; receive script drives the test + return q + + +async def _drain(agent, queue, config, **kw): + return [event async for event in agent.arun(queue, config, **kw)] + + +class TestForcingRule: + @pytest.mark.asyncio + async def test_execute_raises_directing_to_arealtime(self): + agent = LiveAgent(MODEL, realtime_client_factory=_factory(FakeRealtimeClient())) + with pytest.raises(RuntimeError, match="arealtime"): + await agent.execute(AgentState(), {}) + + def test_non_google_model_rejected(self): + with pytest.raises(ValueError, match="google"): + LiveAgent("gpt-4o-realtime") + + +class TestDuplexLoop: + @pytest.mark.asyncio + async def test_yields_audio_and_turn_events_to_caller(self): + client = FakeRealtimeClient([AudioDeltaEvent(data=b"\x01"), TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert [e.type for e in events] == ["audio_delta", "turn_complete"] + assert 
client.closed is True + + @pytest.mark.asyncio + async def test_pump_maps_queue_frames_to_provider(self): + client = FakeRealtimeClient() + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + agent._active_client = client + q = LiveInputQueue() + q.send_audio(b"x", sample_rate=16000) + q.send_text("hi") + q.send_activity_start() + q.send_activity_end() + q.close() + + await agent._pump(q) + + assert client.sent_audio == [(b"x", 16000)] + assert client.sent_text == ["hi"] + assert client.activity == ["start", "end"] + + +class TestToolLoop: + @pytest.mark.asyncio + async def test_tool_call_invokes_toolnode_and_sends_response(self): + def get_weather(city: str) -> str: + return f"sunny in {city}" + + client = FakeRealtimeClient( + [ToolCallEvent(id="c1", name="get_weather", args={"city": "Paris"})] + ) + agent = LiveAgent( + MODEL, + tool_node=ToolNode([get_weather]), + realtime_client_factory=_factory(client), + ) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.tool_responses[0][0] == "c1" + assert client.tool_responses[0][1] == "get_weather" + assert client.tool_responses[0][2] == {"result": "sunny in Paris"} + # caller sees both the tool_call and a synthesized tool_result + assert any(isinstance(e, ToolCallEvent) for e in events) + assert any( + isinstance(e, ToolResultEvent) and e.result == {"result": "sunny in Paris"} + for e in events + ) + + +class TestTransparency: + @pytest.mark.asyncio + async def test_tool_loop_and_bargein_emit_publisher_events(self): + class SpyPublisher(BasePublisher): + def __init__(self): + self.events = [] + + async def publish(self, event): + self.events.append(event) + + async def close(self): + pass + + def sync_close(self): + pass + + spy = SpyPublisher() + InjectQ.get_instance().bind_instance(BasePublisher, spy, allow_concrete=True) + + def ping() -> str: + return "pong" + + client = FakeRealtimeClient( + [ToolCallEvent(id="c1", name="ping", args={}), InterruptedEvent()] + 
) + agent = LiveAgent(MODEL, tool_node=ToolNode([ping]), realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1", "run_id": "r1"}) + + tm = InjectQ.get_instance().try_get(BackgroundTaskManager) + if tm is not None: + await tm.wait_for_all(timeout=2.0) + + kinds = {(str(e.event), str(e.event_type)) for e in spy.events} + # ToolNode publishes tool execution; LiveAgent publishes barge-in. + assert any(ev == "tool_execution" for ev, _ in kinds) + assert any(etype == "interrupted" for _, etype in kinds) + + +class TestBargeIn: + @pytest.mark.asyncio + async def test_interruption_propagates_and_pump_survives(self): + client = FakeRealtimeClient( + [AudioDeltaEvent(data=b"\x01"), InterruptedEvent(), AudioDeltaEvent(data=b"\x02")] + ) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + q = LiveInputQueue() # left open: pump must stay alive across the interruption + + events = [] + async for event in agent.arun(q, {"thread_id": "t1"}): + events.append(event) + if event.type == "interrupted": + q.send_audio(b"\x03", sample_rate=16000) # input still accepted mid-session + await asyncio.sleep(0.02) # let the still-alive pump task drain it + + assert [e.type for e in events] == ["audio_delta", "interrupted", "audio_delta"] + assert (b"\x03", 16000) in client.sent_audio + + +class TestTranscriptPersistence: + @pytest.mark.asyncio + async def test_finished_transcripts_persist_as_messages_no_audio(self): + from agentflow.storage.checkpointer import InMemoryCheckpointer + + client = FakeRealtimeClient( + [ + AudioDeltaEvent(data=b"\xaa\xbb"), # must NOT be persisted + InputTranscriptEvent(text="hello", finished=True), + OutputTranscriptEvent(text="hi there", finished=True), + InputTranscriptEvent(text="partial", finished=False), # must NOT persist + ] + ) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + cp = InMemoryCheckpointer() + state = AgentState() + config = {"thread_id": "t-persist", 
"user_id": "u1"} + + await _drain(agent, _closed_queue(), config, state=state, checkpointer=cp) + + roles = [(m.role, m.content[0].text) for m in state.context] + assert ("user", "hello") in roles + assert ("assistant", "hi there") in roles + assert all(text != "partial" for _, text in roles) + # audio bytes never become messages + for m in state.context: + assert m.metadata.get("modality") == "audio" + persisted = await cp.alist_messages(config) + assert {m.content[0].text for m in persisted} == {"hello", "hi there"} + + +class TestSessionConfig: + @pytest.mark.asyncio + async def test_per_session_overrides_merge_over_realtime_config(self): + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain( + agent, + _closed_queue(), + {"thread_id": "t1", "realtime": {"voice": "Puck", "response_modalities": ["TEXT"]}}, + ) + + # Per-session overrides win; unspecified fields keep the agent's base config. + assert client.connected_config.voice == "Puck" + assert client.connected_config.response_modalities == ["TEXT"] + assert client.connected_config.model == MODEL + + @pytest.mark.asyncio + async def test_no_overrides_uses_base_config_identity(self): + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.connected_config is agent.realtime_config + + +class TestToolAdvertising: + @pytest.mark.asyncio + async def test_tool_node_tools_advertised_to_provider(self): + def get_weather(city: str) -> str: + """Get the weather.""" + return "sunny" + + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent( + MODEL, tool_node=ToolNode([get_weather]), realtime_client_factory=_factory(client) + ) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + names = [t["function"]["name"] for t in client.connected_config.tools] + assert 
"get_weather" in names + + @pytest.mark.asyncio + async def test_explicit_config_tools_take_precedence(self): + def unused(x: int) -> int: + """Unused.""" + return x + + client = FakeRealtimeClient([TurnCompleteEvent()]) + cfg = RealtimeConfig(model=MODEL, tools=[{"sentinel": True}]) + agent = LiveAgent( + MODEL, + realtime_config=cfg, + tool_node=ToolNode([unused]), + realtime_client_factory=_factory(client), + ) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.connected_config.tools == [{"sentinel": True}] + + @pytest.mark.asyncio + async def test_advertised_tools_filtered_by_tools_tags(self): + from agentflow.utils import tool + + @tool(tags=["weather"]) + def get_weather(city: str) -> str: + """Weather.""" + return "x" + + @tool(tags=["math"]) + def add(a: int, b: int) -> int: + """Add.""" + return a + b + + client = FakeRealtimeClient([TurnCompleteEvent()]) + cfg = RealtimeConfig(model=MODEL, tools_tags=["weather"]) + agent = LiveAgent( + MODEL, + realtime_config=cfg, + tool_node=ToolNode([get_weather, add]), + realtime_client_factory=_factory(client), + ) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + names = [t["function"]["name"] for t in client.connected_config.tools] + assert names == ["get_weather"] + + @pytest.mark.asyncio + async def test_no_tool_node_leaves_tools_unset(self): + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.connected_config.tools is None + + +class TestClientHangup: + @pytest.mark.asyncio + async def test_closing_queue_ends_idle_session(self): + # Provider yields one event then goes idle (blocks forever); closing the input + # queue must end the session rather than hang on receive(). 
+ class IdleClient(FakeRealtimeClient): + async def receive(self): + yield AudioDeltaEvent(data=b"\x01") + await asyncio.Event().wait() # never resolves + + agent = LiveAgent(MODEL, realtime_client_factory=_factory(IdleClient())) + q = LiveInputQueue() # left open + + events = [] + + async def run(): + async for event in agent.arun(q, {"thread_id": "t1"}): + events.append(event) + + task = asyncio.create_task(run()) + await asyncio.sleep(0.05) # let the first event flow and the provider go idle + q.close() + await asyncio.wait_for(task, timeout=1.0) + + assert [e.type for e in events] == ["audio_delta"] + + @pytest.mark.asyncio + async def test_closed_queue_still_drains_available_events(self): + # A pre-closed queue must not preempt the provider's already-available events. + client = FakeRealtimeClient([AudioDeltaEvent(data=b"\x01"), TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert [e.type for e in events] == ["audio_delta", "turn_complete"] + + +class TestResumption: + @pytest.mark.asyncio + async def test_session_update_caches_and_persists_handle(self): + from agentflow.storage.checkpointer import InMemoryCheckpointer + + client = FakeRealtimeClient([SessionUpdateEvent(resumption_handle="H1")]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + cp = InMemoryCheckpointer() + config = {"thread_id": "t-resume", "user_id": "u1"} + + await _drain(agent, _closed_queue(), config, checkpointer=cp) + + assert agent._resume_handle == "H1" + thread = await cp.aget_thread(config) + assert thread.metadata["resumption_handle"] == "H1" + + @pytest.mark.asyncio + async def test_go_away_reconnects_with_stored_handle(self): + first = FakeRealtimeClient([SessionUpdateEvent(resumption_handle="H1"), GoAwayEvent(time_left="2s")]) + second = FakeRealtimeClient([AudioDeltaEvent(data=b"\x09"), TurnCompleteEvent()]) + agent = LiveAgent(MODEL, 
realtime_client_factory=_factory(first, second)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + # second socket opened with the cached handle, stream continued seamlessly + assert second.connected_with == ["H1"] + assert [e.type for e in events] == ["session_update", "go_away", "audio_delta", "turn_complete"] + assert first.closed and second.closed + + @pytest.mark.asyncio + async def test_go_away_without_handle_reconnects_fresh(self): + # go_away before any session_update: must still reconnect (fresh, no handle) + # instead of terminating the session. + first = FakeRealtimeClient([GoAwayEvent(time_left="1s")]) + second = FakeRealtimeClient([AudioDeltaEvent(data=b"\x07"), TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(first, second)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert second.connected_with == [None] + assert [e.type for e in events] == ["go_away", "audio_delta", "turn_complete"] + assert first.closed and second.closed + + @pytest.mark.asyncio + async def test_cross_session_reseeds_history(self): + from agentflow.storage.checkpointer import InMemoryCheckpointer + + cp = InMemoryCheckpointer() + config = {"thread_id": "t-cross", "user_id": "u1"} + await cp.aput_messages( + config, + [Message.text_message("earlier turn", role="user")], + ) + + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), config, checkpointer=cp) + + assert client.reseeded is not None + assert client.reseeded[0].content[0].text == "earlier turn" diff --git a/tests/realtime/test_package_exports.py b/tests/realtime/test_package_exports.py new file mode 100644 index 00000000..8e4b73b9 --- /dev/null +++ b/tests/realtime/test_package_exports.py @@ -0,0 +1,38 @@ +"""The realtime package must re-export the Phase 1 public surface (SDK-first usage).""" + +import agentflow.core.realtime as 
rt + + +def test_public_symbols_are_exported(): + expected = { + "RealtimeClient", + "RealtimeConfig", + "RealtimeEvent", + "VADConfig", + "AudioDeltaEvent", + "InputTranscriptEvent", + "OutputTranscriptEvent", + "ToolCallEvent", + "ToolResultEvent", + "TurnCompleteEvent", + "InterruptedEvent", + "SessionUpdateEvent", + "GoAwayEvent", + "AgentChangedEvent", + "ErrorEvent", + "LiveInputQueue", + "LiveInput", + "GeminiLiveClient", + "normalize_message", + } + assert expected.issubset(set(rt.__all__)) + for name in expected: + assert hasattr(rt, name), f"{name} missing from agentflow.core.realtime" + + +def test_queue_constructs_from_top_level_import(): + q = rt.LiveInputQueue() + q.send_text("hi") + item = q.get_nowait() + assert isinstance(item, rt.LiveInput) + assert item.kind == "text" diff --git a/tests/realtime/test_queue.py b/tests/realtime/test_queue.py new file mode 100644 index 00000000..3ca238bd --- /dev/null +++ b/tests/realtime/test_queue.py @@ -0,0 +1,71 @@ +"""Unit tests for the upstream decoupler (agentflow.core.realtime.queue).""" + +import asyncio + +import pytest + +from agentflow.core.realtime.queue import LiveInput, LiveInputQueue + + +class TestLiveInputQueuePut: + def test_send_audio_is_nonblocking_and_enqueues_audio_frame(self): + q = LiveInputQueue() + q.send_audio(b"\x00\x01", sample_rate=16000) + item = q.get_nowait() + assert isinstance(item, LiveInput) + assert item.kind == "audio" + assert item.data == b"\x00\x01" + assert item.sample_rate == 16000 + + def test_send_text_enqueues_text_frame(self): + q = LiveInputQueue() + q.send_text("hello") + item = q.get_nowait() + assert item.kind == "text" + assert item.text == "hello" + + def test_activity_markers_enqueue_control_frames(self): + q = LiveInputQueue() + q.send_activity_start() + q.send_activity_end() + assert q.get_nowait().kind == "activity_start" + assert q.get_nowait().kind == "activity_end" + + def test_close_enqueues_sentinel_and_marks_closed(self): + q = LiveInputQueue() + 
+ q.close() + assert q.get_nowait().kind == "close" + assert q.closed is True + + def test_put_after_close_is_dropped(self): + q = LiveInputQueue() + q.close() + q.get_nowait() # drain the close sentinel + q.send_audio(b"\x00", sample_rate=16000) # should be a no-op, not raise + with pytest.raises(asyncio.QueueEmpty): + q.get_nowait() + + +class TestLiveInputQueueConsume: + @pytest.mark.asyncio + async def test_get_awaits_until_item_available(self): + q = LiveInputQueue() + + async def producer(): + await asyncio.sleep(0.01) + q.send_text("late") + + producer_task = asyncio.create_task(producer()) # keep a strong reference: the event loop only holds tasks weakly + item, _ = await asyncio.gather(q.get(), producer_task) # awaiting the producer surfaces its errors instead of hanging on get() + assert item.kind == "text" + assert item.text == "late" + + @pytest.mark.asyncio + async def test_async_iteration_stops_at_close(self): + q = LiveInputQueue() + q.send_text("a") + q.send_audio(b"\x01", sample_rate=16000) + q.close() + + seen = [item async for item in q] + assert [i.kind for i in seen] == ["text", "audio"] diff --git a/tests/realtime/test_realtime_events_enum.py b/tests/realtime/test_realtime_events_enum.py new file mode 100644 index 00000000..72296696 --- /dev/null +++ b/tests/realtime/test_realtime_events_enum.py @@ -0,0 +1,19 @@ +"""Phase 3: the publisher taxonomy gains a REALTIME category + transcript content type +so all publisher backends inherit realtime telemetry without per-backend changes.""" + +from agentflow.runtime.publisher.events import ContentType, Event, EventType + + +def test_realtime_event_category_added(): + assert Event.REALTIME.value == "realtime" + + +def test_transcript_content_type_added(): + assert ContentType.TRANSCRIPT.value == "transcript" + + +def test_existing_members_unchanged(): + # Regression guard: realtime additions must not perturb the existing taxonomy.
+ assert Event.TOOL_EXECUTION.value == "tool_execution" + assert EventType.INTERRUPTED.value == "interrupted" + assert ContentType.AUDIO.value == "audio" From 28a6ac53f3014ab874dafe5c1a212a6ea89963d9 Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Mon, 15 Jun 2026 18:27:29 +0600 Subject: [PATCH 2/3] Enhance Gemini Live Client and Examples - Introduced a new `_transcript_event` function to streamline transcript event creation in `gemini_live.py`. - Updated `normalize_message` to utilize the new transcript event function, ensuring proper handling of finished markers. - Refactored `GeminiLiveClient` to support configurable project and location for Vertex AI, improving client initialization. - Added robust error handling for missing credentials and project configurations in the client. - Enhanced examples for live microphone and file input, improving user experience and documentation clarity. - Updated README files to reflect new features and usage instructions for the microphone example. - Improved test coverage for transcript handling, ensuring proper accumulation and flushing of partial transcripts. - Added tests for client credential resolution and session management, ensuring robust functionality. - Included documentation for production readiness, addressing type checking, timeout configurations, and logging enhancements. 
--- Need_TO_INCLUDE_IN_DOCS.md | 56 ++++++ Plan.md | 70 +++++++ agentflow/core/realtime/base.py | 6 +- agentflow/core/realtime/live_agent.py | 171 ++++++++++++++++-- .../core/realtime/providers/gemini_live.py | 123 ++++++++++--- examples/realtime/README.md | 7 +- examples/realtime/audio_agent_file.py | 7 +- examples/realtime/audio_agent_mic.py | 113 +++++++++--- examples/realtime/graph.py | 7 +- tests/realtime/test_gemini_live.py | 126 ++++++++++++- tests/realtime/test_live_agent.py | 126 +++++++++++++ 11 files changed, 735 insertions(+), 77 deletions(-) create mode 100644 Need_TO_INCLUDE_IN_DOCS.md create mode 100644 Plan.md diff --git a/Need_TO_INCLUDE_IN_DOCS.md b/Need_TO_INCLUDE_IN_DOCS.md new file mode 100644 index 00000000..1cdd45b4 --- /dev/null +++ b/Need_TO_INCLUDE_IN_DOCS.md @@ -0,0 +1,56 @@ +# Need to Include in Docs + +Production-readiness work completed in this pass. Items below introduce or change +user-facing behavior and should be documented. + +## Type checking (PEP 561) +- The package now ships a `py.typed` marker. Downstream `mypy`/`pyright` will + type-check against Agentflow's annotations. + +## Configurable LLM call timeout +- All LLM clients now apply a default request timeout (`600s`) so a stalled + provider cannot hang a run indefinitely. +- Override globally via the `AGENTFLOW_LLM_TIMEOUT` environment variable + (seconds), or programmatically: + - `from agentflow.core.llm import set_default_llm_timeout, get_default_llm_timeout, DEFAULT_LLM_TIMEOUT_SECONDS` + - `set_default_llm_timeout(120.0)` / `set_default_llm_timeout(None)` to reset. +- An explicit per-client `timeout=` still takes precedence. + +## CompiledGraph async context manager +- `CompiledGraph` supports `async with`: + ```python + async with await build_and_compile_graph() as graph: + await graph.ainvoke(input_data) + # aclose() runs automatically on exit, even if the body raises + ``` +- `aclose()` is now idempotent (second call returns `{"status": "already_closed"}`). 
+ +## Circuit breaker for LLM calls (opt-in) +- Complements retry + `fallback_models`: once a `(provider, model)` fails + `circuit_breaker_threshold` times in a row, its circuit opens and that target + is skipped (straight to the next fallback) for `circuit_breaker_reset_timeout` + seconds, instead of being retried on every call. +- Configure via `RetryConfig`: + - `circuit_breaker_enabled: bool = False` + - `circuit_breaker_threshold: int = 5` + - `circuit_breaker_reset_timeout: float = 30.0` + +## Secret redaction for logs +- New helpers in `agentflow.utils`: + - `mask_secrets(text)` — redacts API keys, `Bearer` tokens, `key=value` + secrets, and signed-URL credential query params. + - `SecretRedactionFilter` — a `logging.Filter`; add it to a handler to cover + all loggers that propagate to it. + - `install_secret_redaction(logger_name="agentflow")` — convenience installer. + +## ConsolePublisher logging option +- `ConsolePublisher` is a dev/debug, opt-in publisher (use a real transport in + production). It writes to stdout by default; pass `{"use_logger": True}` to + route events through the `agentflow.publisher` logger instead of stdout. + +## Project / repo +- Dependencies now have version bounds (e.g. `pydantic>=2,<3`). +- mypy runs in pre-commit/CI (phased adoption; see `CONTRIBUTING.md`). +- Test coverage gate raised to 80%. +- Added `SECURITY.md` and `CONTRIBUTING.md`. +- Added Dependabot config and a CodeQL workflow. diff --git a/Plan.md b/Plan.md new file mode 100644 index 00000000..7df3dc08 --- /dev/null +++ b/Plan.md @@ -0,0 +1,70 @@ +Agentflow Core Python SDK — Production Readiness Review +Scope: the 10xscale-agentflow package (v0.7.5.1) in agentflow/. Overall this is a mature, well-structured framework with strong CI/CD release automation, a comprehensive test suite (138 files), a clean exception hierarchy, and genuinely good OTEL support. The gaps below are what stand between it and "production-grade SDK that other teams depend on." 
+ +Blockers (fix before claiming production-stable) +1. No py.typed marker — the package ships untyped to consumers. +Confirmed missing: no agentflow/agentflow/py.typed and no package-data rule in pyproject.toml. Despite extensive internal type hints, PEP 561 means downstream mypy/pyright see Any for every Agentflow symbol. For a "Production/Stable"-classified library this is the single highest-leverage fix: add the empty marker plus a [tool.setuptools.package-data] entry. + +2. Core dependencies are unpinned. +pyproject.toml:51,64-66 lists pydantic, PyYAML, python-dotenv, and pydantic-ai with no version bounds at all. A Pydantic v2→v3 or breaking pydantic-ai release will silently break installs in the field. At minimum set lower bounds (pydantic>=2,<3). The uv.lock protects this repo's own builds but does nothing for pip install 10xscale-agentflow users. + +3. README claims native Anthropic/Claude support that does not exist. +README headline (lines ~17, 24, 190) advertises native Anthropic and ANTHROPIC_API_KEY, but client_factory.py detect_provider only resolves google or openai. A user running detect_provider("claude-3-opus") gets openai and a construction failure. Either remove the claim or document it as "via OpenAI-compatible endpoint only." (Note: the README import-path drift that CLAUDE.md warns about appears to have been fixed — examples now use correct agentflow.core.* paths. CLAUDE.md's "broken examples" note is itself stale.) + +High priority +4. No default timeout on LLM calls. +client_factory.py:34 accepts a timeout kwarg but enforces no default, and there is no top-level timeout wrapping invoke/ainvoke. A hung provider connection blocks indefinitely. Add a sane default client timeout and a per-request ceiling. + +5. mypy is configured but never runs. +pyproject.toml has [tool.mypy], but it is absent from both .pre-commit-config.yaml and .github/workflows/ci.yml (confirmed: zero mypy references in either). It is dead config. 
Either wire it into CI or stop advertising type safety. This pairs directly with #1. + +6. CI tests a single Python version on a single OS. +ci.yml runs only Python 3.13 on ubuntu-latest, yet the package claims >=3.12 and classifies 3.12/3.13. 3.12 is untested. Add a 3.12/3.13 matrix; consider macOS. + +Medium priority +7. Silent exception swallowing in media + callback paths. Broad except Exception with debug-only logging in media_resolver.py (e.g. lines 100, 193, 234) and throughout callbacks.py hides real failures in production. Narrow these or at least log at warning with context. + +8. No __aenter__/__aexit__ on CompiledGraph. Cleanup relies on callers remembering aclose(). Publisher backends (Kafka/RabbitMQ/Redis) may leak connections if shutdown raises. Add the async-context-manager protocol to the top-level graph. + +9. Missing governance/policy files. No SECURITY.md (no vuln disclosure path) and no CONTRIBUTING.md. Both are table-stakes for an OSS framework inviting external dependence. + +10. No migration guide or deprecation policy. The 0.5→0.7 flat→nested import restructure (agentflow.state → agentflow.core.state) has no MIGRATION.md, and there's no structured deprecation mechanism. changelogs.md is unstructured (no dates, no Keep-a-Changelog format). For a stable API, document the deprecation contract. + +11. ConsolePublisher uses print(). console_publisher.py:61 writes to stdout. Acceptable for a console publisher by design, but worth confirming it never runs by default in a server context. + +Lower priority / nice-to-have +No circuit breaker to complement fallback_models — repeated failures to a dead provider retry every call. +No secret masking in logs; API keys/auth headers could surface in debug output. +No Dependabot/CodeQL beyond Bandit. +Coverage gate at 70% is low for a production framework; 80%+ is a better bar. +normal_tests/ (12 files) excluded from discovery with no documented rationale. 
+What's already solid (don't regress) +Release automation: tag-gated, OIDC trusted PyPI publishing with version verification (release.yml). +Exception hierarchy with error_code/context/to_dict() and an explicit TransientStorageError for retryability. +OTEL tracing with gen_ai semantic conventions and three observability levels. +Proper library logging (getLogger(__name__) + NullHandler, no stdout pollution in the core path). +Parallel tool execution genuinely works (invoke_node_handler.py:182). +Comprehensive __all__ exports across all subpackages. +Suggested order of attack +py.typed + dependency pinning + fix the Anthropic README claim (an afternoon; ships in next patch). +LLM timeouts + wire mypy into CI + 3.12 matrix. +SECURITY.md, CONTRIBUTING.md, MIGRATION.md. +Tighten exception handling and add the graph async-context-manager protocol. +Items 1-3 are mostly mechanical and would move the package materially toward production-grade. Want me to implement the quick wins in #1 (add py.typed, set dependency bounds, correct the README)? + + + + +4. let's add it and keep a default that can be changed via an env var or globally +5. let's add it +6. fine for now, will focus on it later +7. let's fix it +8. let's add it +9. let's add it +10. skip for now +11. add it + +No circuit breaker to complement fallback_models — repeated failures to a dead provider retry every call. +No secret masking in logs; API keys/auth headers could surface in debug output. +No Dependabot/CodeQL beyond Bandit. +Coverage gate at 70% is low for a production framework; 80%+ is a better bar.
\ No newline at end of file diff --git a/agentflow/core/realtime/base.py b/agentflow/core/realtime/base.py index 859a7152..c531ed1b 100644 --- a/agentflow/core/realtime/base.py +++ b/agentflow/core/realtime/base.py @@ -8,7 +8,7 @@ from typing import Annotated, Any, Literal, Protocol, Union, runtime_checkable -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator # Audio format facts for Gemini Live: input PCM16 mono @ 16kHz, output PCM16 @ 24kHz. @@ -146,6 +146,10 @@ class RealtimeConfig(BaseModel): is validated to a single entry. """ + # validate_default so the default modality list is held to the same one-modality rule + # as explicit values (otherwise a bad default silently bypasses the validator below). + model_config = ConfigDict(validate_default=True) + model: str response_modalities: list[ResponseModality] = Field(default_factory=lambda: ["AUDIO"]) voice: str | None = None diff --git a/agentflow/core/realtime/live_agent.py b/agentflow/core/realtime/live_agent.py index d0040512..4e59ec0a 100644 --- a/agentflow/core/realtime/live_agent.py +++ b/agentflow/core/realtime/live_agent.py @@ -29,7 +29,14 @@ from agentflow.core.graph.base_agent import BaseAgent from agentflow.core.graph.tool_node import ToolNode from agentflow.core.llm import detect_provider -from agentflow.core.realtime.base import RealtimeClient, RealtimeConfig, ToolResultEvent +from agentflow.core.realtime.base import ( + ErrorEvent, + InputTranscriptEvent, + OutputTranscriptEvent, + RealtimeClient, + RealtimeConfig, + ToolResultEvent, +) from agentflow.core.realtime.providers.gemini_live import GeminiLiveClient from agentflow.core.state import AgentState, Message, TextBlock, add_messages from agentflow.runtime.publisher.events import ContentType, Event, EventModel, EventType @@ -63,6 +70,8 @@ def __init__( ) -> None: api_key: str | None = kwargs.pop("api_key", None) use_vertex_ai: bool = kwargs.pop("use_vertex_ai", False) + 
project: str | None = kwargs.pop("project", None) + location: str | None = kwargs.pop("location", None) provider = detect_provider(model, use_vertex_ai) if provider != "google": @@ -89,7 +98,12 @@ def __init__( # One client per *connection*; the factory lets reconnects get a fresh socket # and lets tests inject a fake provider. self._client_factory: Callable[[], RealtimeClient] = realtime_client_factory or ( - lambda: GeminiLiveClient(api_key=api_key, use_vertex_ai=use_vertex_ai) + lambda: GeminiLiveClient( + api_key=api_key, + use_vertex_ai=use_vertex_ai, + project=project, + location=location, + ) ) self._active_client: RealtimeClient | None = None self._resume_handle: str | None = None @@ -97,6 +111,17 @@ def __init__( # sends on a socket being torn down, and always picks up the reconnected client. self._send_lock = asyncio.Lock() + # Per-session transcript accumulators (provider streams partial chunks; we flush + # the concatenation on the finished marker). Reset at the start of each arun(). + self._input_transcript_buf = "" + self._output_transcript_buf = "" + + # Error-driven reconnect backoff (go_away reconnects are immediate; only transient + # drops back off). Instance attributes so tests can shrink them. + self._reconnect_base_delay = 0.5 + self._reconnect_max_delay = 10.0 + self._reconnect_max_attempts = 5 + # Builder mixins (no-op when their config is None). 
self._setup_memory(memory) self._setup_skills(skills) @@ -138,6 +163,8 @@ async def arun( state = state if state is not None else AgentState() if callback_manager is None: callback_manager = CallbackManager() + self._input_transcript_buf = "" + self._output_transcript_buf = "" rt = self._session_realtime_config(config) rt = await self._resolve_session_tools(rt) @@ -145,18 +172,25 @@ async def arun( client = self._client_factory() await client.connect(rt, resume_handle=handle) self._active_client = client - await self._maybe_reseed(config, checkpointer, context_manager) + # Only reseed when the provider did NOT restore context from a handle; otherwise the + # model would receive the whole conversation twice (handle restore + reseed). + await self._maybe_reseed( + config, checkpointer, context_manager, resumed=handle is not None + ) # Closing the input queue ends the session: the pump sets this when it drains the # close sentinel, and the receive loop stops once the provider goes idle. stop_event = asyncio.Event() pump_task = asyncio.create_task(self._pump(input_queue, stop_event)) + attempts = 0 # consecutive error-driven reconnect attempts (reset on healthy receive) try: while True: reconnect = False forced = False # go_away: reconnect even after input closed, to finish the turn + received_any = False try: async for event in self._receive_until_stop(self._active_client, stop_event): + received_any = True for out in await self._handle_event( event, config, state, checkpointer, callback_manager ): @@ -173,10 +207,18 @@ async def arun( logger.warning("realtime receive loop error; attempting resume", exc_info=True) reconnect = True - if reconnect and rt.session_resumption and (forced or not stop_event.is_set()): - await self._reconnect(rt) - continue - break + # A turn that produced events means the connection is healthy again. 
+ if received_any: + attempts = 0 + + resumable = reconnect and rt.session_resumption + if not (resumable and (forced or not stop_event.is_set())): + break + + attempts, fatal = await self._attempt_reconnect(rt, forced, attempts) + if fatal is not None: + yield fatal + break finally: pump_task.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -295,17 +337,18 @@ async def _handle_event( if kind == "tool_call": return await self._run_tool(event, config, state, callback_manager) - if kind == "input_transcript" and event.finished: - await self._persist_transcript(event.text, "user", config, state, checkpointer) - self._publish_realtime( - EventType.RESULT, config, ContentType.TRANSCRIPT, "input_transcript" - ) - elif kind == "output_transcript" and event.finished: - await self._persist_transcript(event.text, "assistant", config, state, checkpointer) - self._publish_realtime( - EventType.RESULT, config, ContentType.TRANSCRIPT, "output_transcript" + if kind == "input_transcript": + return await self._accumulate_transcript(event, "user", config, state, checkpointer) + if kind == "output_transcript": + return await self._accumulate_transcript( + event, "assistant", config, state, checkpointer ) - elif kind == "interrupted": + + if kind == "interrupted": + # Barge-in discards the model's in-flight turn; drop any partial transcript so + # the restarted turn isn't concatenated onto the abandoned one. 
+ self._input_transcript_buf = "" + self._output_transcript_buf = "" self._publish_realtime(EventType.INTERRUPTED, config, ContentType.AUDIO, "barge_in") elif kind == "go_away": self._publish_realtime(EventType.UPDATE, config, ContentType.UPDATE, "go_away") @@ -316,6 +359,47 @@ async def _handle_event( return [event] + async def _accumulate_transcript( + self, + event: RealtimeEvent, + role: str, + config: dict[str, Any], + state: AgentState, + checkpointer: BaseCheckpointer | None, + ) -> list[RealtimeEvent]: + """Accumulate streamed transcript chunks and consolidate on the finished marker. + + Providers stream transcripts as partial chunks (``finished=False``) and end with a + finished marker that usually carries no text. Partials pass through unchanged for live + display; on ``finished`` we persist the full turn and emit a single consolidated + event carrying the complete text (so consumers get the whole transcript without + having to accumulate themselves). Turns with no transcribed text emit nothing. 
+ """ + is_user = role == "user" + if is_user: + self._input_transcript_buf += event.text + buffered = self._input_transcript_buf + else: + self._output_transcript_buf += event.text + buffered = self._output_transcript_buf + + if not event.finished: + return [event] # stream the partial for live UIs + + full = buffered.strip() + if is_user: + self._input_transcript_buf = "" + else: + self._output_transcript_buf = "" + if not full: + return [] # nothing transcribed this turn; drop the empty finish marker + + await self._persist_transcript(full, role, config, state, checkpointer) + lifecycle = "input_transcript" if is_user else "output_transcript" + self._publish_realtime(EventType.RESULT, config, ContentType.TRANSCRIPT, lifecycle) + event_cls = InputTranscriptEvent if is_user else OutputTranscriptEvent + return [event_cls(text=full, finished=True)] + async def _run_tool( self, event: RealtimeEvent, @@ -337,8 +421,13 @@ async def _run_tool( ) result = self._extract_tool_result(invoked) - # Socket stays open; feed the result back to the model. - await self._active_client.send_tool_response(event.id, event.name, result) + # Socket stays open; feed the result back to the model. Hold _send_lock (and re-read + # the live client inside it) so this send is serialized against the pump and any + # concurrent reconnect, exactly like the pump's own sends. + async with self._send_lock: + client = self._active_client + if client is not None: + await client.send_tool_response(event.id, event.name, result) return [event, ToolResultEvent(id=event.id, result=result)] @staticmethod @@ -411,8 +500,12 @@ async def _maybe_reseed( config: dict[str, Any], checkpointer: BaseCheckpointer | None, context_manager: BaseContextManager | None, + *, + resumed: bool = False, ) -> None: - if checkpointer is None: + # When the provider restored context from a resumption handle, reseeding would + # replay the whole conversation a second time. 
+ if checkpointer is None or resumed: return try: history = await checkpointer.alist_messages(config) @@ -429,6 +522,44 @@ async def _maybe_reseed( if history: await self._active_client.reseed_history(list(history)) + async def _attempt_reconnect( + self, rt: RealtimeConfig, forced: bool, attempts: int + ) -> tuple[int, ErrorEvent | None]: + """Reconnect after a drop. Returns ``(attempts, fatal_error_or_None)``. + + ``forced`` (go_away) is an expected provider rotation: reconnect promptly, no backoff. + Error-driven drops back off exponentially with a hard attempt cap so a flapping or + down provider can't spin a tight reconnect storm; once the cap is hit a fatal + :class:`ErrorEvent` is returned for the caller to surface before ending the session. + """ + if forced: + with contextlib.suppress(Exception): + await self._reconnect(rt) + return 0, None + + attempts += 1 + if attempts > self._reconnect_max_attempts: + logger.error( + "realtime reconnect attempts exhausted (%d); giving up", + self._reconnect_max_attempts, + ) + return attempts, ErrorEvent( + code="reconnect_failed", + message=( + "realtime session lost and could not be resumed after " + f"{self._reconnect_max_attempts} attempts" + ), + fatal=True, + ) + + delay = min( + self._reconnect_base_delay * (2 ** (attempts - 1)), self._reconnect_max_delay + ) + await asyncio.sleep(delay) + with contextlib.suppress(Exception): + await self._reconnect(rt) + return attempts, None + async def _reconnect(self, rt: RealtimeConfig) -> None: async with self._send_lock: old = self._active_client diff --git a/agentflow/core/realtime/providers/gemini_live.py b/agentflow/core/realtime/providers/gemini_live.py index 46ab168c..93ffdcbe 100644 --- a/agentflow/core/realtime/providers/gemini_live.py +++ b/agentflow/core/realtime/providers/gemini_live.py @@ -10,6 +10,7 @@ from __future__ import annotations import logging +import os import re from collections.abc import AsyncIterator from typing import Any @@ -44,6 +45,20 @@ def 
_rate_from_mime(mime_type: str | None, default: int) -> int: return int(match.group(1)) if match else default +def _transcript_event(tx: Any, event_cls: Any) -> RealtimeEvent | None: + """Build a transcript event from a provider transcription, or None when there's nothing. + + Emits on text OR finished so the finish marker (often text=None) is never dropped. + """ + if tx is None: + return None + text = getattr(tx, "text", None) + finished = bool(getattr(tx, "finished", False)) + if text is None and not finished: + return None + return event_cls(text=text or "", finished=finished) + + def normalize_message(message: Any) -> list[RealtimeEvent]: """Map a google ``LiveServerMessage`` to zero or more normalized events. @@ -62,26 +77,27 @@ def normalize_message(message: Any) -> list[RealtimeEvent]: rate = _rate_from_mime(getattr(inline, "mime_type", None), OUTPUT_SAMPLE_RATE) events.append(AudioDeltaEvent(data=data, sample_rate=rate)) - in_tx = getattr(content, "input_transcription", None) - if in_tx is not None and getattr(in_tx, "text", None) is not None: - events.append( - InputTranscriptEvent( - text=in_tx.text, finished=bool(getattr(in_tx, "finished", False)) - ) - ) - - out_tx = getattr(content, "output_transcription", None) - if out_tx is not None and getattr(out_tx, "text", None) is not None: - events.append( - OutputTranscriptEvent( - text=out_tx.text, finished=bool(getattr(out_tx, "finished", False)) - ) - ) + # Transcripts stream as partial chunks; the terminating chunk often carries + # finished=True with text=None. Emit on text OR finished so the finish marker + # is never dropped (consumers accumulate partials and flush on finished). 
+ in_ev = _transcript_event( + getattr(content, "input_transcription", None), InputTranscriptEvent + ) + if in_ev is not None: + events.append(in_ev) + out_ev = _transcript_event( + getattr(content, "output_transcription", None), OutputTranscriptEvent + ) + if out_ev is not None: + events.append(out_ev) if getattr(content, "interrupted", None): events.append(InterruptedEvent()) - if getattr(content, "generation_complete", None) or getattr(content, "turn_complete", None): + # ``turn_complete`` is the single authoritative end-of-turn signal. ``generation_complete`` + # arrives in a separate earlier message within the same turn; mapping both to + # TurnCompleteEvent would emit two per turn and double-count turn boundaries. + if getattr(content, "turn_complete", None): events.append(TurnCompleteEvent()) tool_call = getattr(message, "tool_call", None) @@ -121,11 +137,15 @@ def __init__( connector: Any | None = None, api_key: str | None = None, use_vertex_ai: bool = False, + project: str | None = None, + location: str | None = None, ) -> None: self._client = client self._connector = connector self._api_key = api_key self._use_vertex_ai = use_vertex_ai + self._project = project + self._location = location self._config: RealtimeConfig | None = None self._cm: Any | None = None self._session: Any | None = None @@ -149,10 +169,49 @@ def _genai(): def _ensure_client(self) -> Any: if self._client is None: - genai, _ = self._genai() - self._client = genai.Client(api_key=self._api_key, vertexai=self._use_vertex_ai) + self._client = self._build_client() return self._client + def _build_client(self) -> Any: + """Construct a google-genai client, supporting both auth modes (mirrors the + turn-based factory in ``agentflow.core.llm.client_factory``). + + - Vertex AI / service account: ``use_vertex_ai=True`` with ``GOOGLE_CLOUD_PROJECT`` + (and optional ``GOOGLE_CLOUD_LOCATION``); credentials come from Application + Default Credentials (e.g. 
a service-account key via ``GOOGLE_APPLICATION_CREDENTIALS``). + - Developer API key: explicit ``api_key`` or ``GEMINI_API_KEY`` / ``GOOGLE_API_KEY``. + + ``vertexai`` is always passed explicitly so the ``GOOGLE_GENAI_USE_VERTEXAI`` env var + can't silently flip the mode out from under the caller. + """ + genai, _ = self._genai() + + if self._use_vertex_ai: + project = self._project or os.getenv("GOOGLE_CLOUD_PROJECT") + location = self._location or os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1") + if not project: + raise ValueError( + "Vertex AI realtime requires a project: pass project=... or set " + "GOOGLE_CLOUD_PROJECT (credentials via Application Default Credentials / " + "GOOGLE_APPLICATION_CREDENTIALS)." + ) + logger.info( + "Creating Gemini Live client (Vertex AI, project=%s, location=%s)", + project, + location, + ) + return genai.Client(vertexai=True, project=project, location=location) + + api_key = self._api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + if not api_key: + raise ValueError( + "Gemini realtime requires credentials: set GEMINI_API_KEY or GOOGLE_API_KEY " + "(or pass api_key=...), or use Vertex AI with use_vertex_ai=True and " + "GOOGLE_CLOUD_PROJECT." + ) + logger.info("Creating Gemini Live client (API key)") + return genai.Client(vertexai=False, api_key=api_key) + def _get_connector(self) -> Any: if self._connector is not None: return self._connector @@ -264,21 +323,37 @@ async def reseed_history(self, messages: list[Any]) -> None: _, types = self._genai() turns = [] for message in messages: + # Gemini live turns are user/model only. System prompts are passed via + # system_instruction, and tool turns are not reseedable as dialogue, so skip + # both rather than mislabeling them as user input. 
+ role = getattr(message, "role", "user") + if role not in ("user", "assistant"): + continue text = "".join( getattr(block, "text", "") or "" for block in getattr(message, "content", []) or [] ) if not text: continue - role = "model" if getattr(message, "role", "user") == "assistant" else "user" - turns.append(types.Content(role=role, parts=[types.Part.from_text(text=text)])) + gem_role = "model" if role == "assistant" else "user" + turns.append(types.Content(role=gem_role, parts=[types.Part.from_text(text=text)])) if turns: await session.send_client_content(turns=turns, turn_complete=True) async def receive(self) -> AsyncIterator[RealtimeEvent]: - session = self._require_session() - async for message in session.receive(): - for event in normalize_message(message): - yield event + # Gemini Live's session.receive() completes after each turn_complete; you must call + # it again for the next turn. Loop so a session spans multiple turns. A receive() + # that yields no messages means the connection is going away, so stop (and a dropped + # socket raises out of receive(), which the caller treats as a transient drop). + self._require_session() + while self._session is not None: + session = self._session + produced = False + async for message in session.receive(): + produced = True + for event in normalize_message(message): + yield event + if not produced: + break async def close(self) -> None: cm = self._cm diff --git a/examples/realtime/README.md b/examples/realtime/README.md index 7e85c14a..5ce023df 100644 --- a/examples/realtime/README.md +++ b/examples/realtime/README.md @@ -31,14 +31,15 @@ python examples/realtime/audio_agent_file.py path/to/input.wav # 16 kHz mono P # writes out.wav and prints transcripts + tool calls ``` -## 2. Live microphone (full duplex) +## 2. Live microphone (full duplex, React-style tool calling) -Speak and the agent talks back, with barge-in and tool calling. +Speak and the agent talks back out loud, with barge-in. 
Ask about the weather and it +calls the `get_weather` tool, then speaks the result (reason -> tool -> respond). ```bash pip install sounddevice python examples/realtime/audio_agent_mic.py -# speak; Ctrl+C to stop +# then say: "What's the weather in Tokyo?" (Ctrl+C to stop) ``` ## 3. Through the API server (`/v1/graph/live` WebSocket) diff --git a/examples/realtime/audio_agent_file.py b/examples/realtime/audio_agent_file.py index bbc8f4ba..3771750c 100644 --- a/examples/realtime/audio_agent_file.py +++ b/examples/realtime/audio_agent_file.py @@ -22,14 +22,17 @@ import sys import wave -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv from agentflow.core.realtime.base import OUTPUT_SAMPLE_RATE, RealtimeConfig from agentflow.core.realtime.queue import LiveInputQueue from agentflow.prebuilt.agent import AudioAgent -load_dotenv() +# Load .env reliably regardless of the launch directory. +_HERE = os.path.dirname(os.path.abspath(__file__)) +load_dotenv(os.path.join(_HERE, ".env")) +load_dotenv(find_dotenv(usecwd=True)) MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-live-2.5-flash-preview") diff --git a/examples/realtime/audio_agent_mic.py b/examples/realtime/audio_agent_mic.py index 1bb96c16..0b27cc92 100644 --- a/examples/realtime/audio_agent_mic.py +++ b/examples/realtime/audio_agent_mic.py @@ -1,48 +1,101 @@ -"""Live microphone audio-to-audio with AudioAgent + Gemini Live. +"""Live voice weather assistant -- AudioAgent + Gemini Live, React-style tool calling. -Speak into your microphone and the agent talks back in real time. It supports barge-in -(start talking while it is speaking and it stops to listen) and tool calls. This is the -full duplex demo; for a headless/no-hardware version see ``audio_agent_file.py``. +Talk to it through your microphone and it talks back in real time. Ask "what's the +weather in Tokyo?" 
and the model calls the ``get_weather`` tool, then speaks the answer +it gets back -- the realtime analog of a ReactAgent's reason -> tool -> respond loop. + +Features + - Voice playback: the model's reply is played out loud on your speakers. + - Tool calling: ``get_weather`` is advertised to the model and invoked on demand. + - Echo-safe by default: the mic is muted while the agent is speaking, so it doesn't + hear (and reply to) its own voice through your speakers. + +Echo / feedback + Without echo cancellation, your speaker audio leaks into the mic and the model + transcribes its own replies as your input. This demo avoids that by muting the mic + while the agent talks (half-duplex). Use headphones and set MIC_FULL_DUPLEX=1 for + true full duplex with barge-in (speak over the agent to interrupt it). Setup pip install "10xscale-agentflow[realtime]" sounddevice - export GEMINI_API_KEY=... - export GEMINI_LIVE_MODEL=gemini-live-2.5-flash-preview # optional, see README + export GEMINI_API_KEY=... # or Vertex AI env (see README) + export GEMINI_LIVE_MODEL=... # optional, see README Run python examples/realtime/audio_agent_mic.py - # speak; press Ctrl+C to stop. + # then say e.g. "What's the weather in Paris?" -- press Ctrl+C to stop. """ import asyncio +import contextlib import os import sys -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv from agentflow.core.realtime.base import INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE, RealtimeConfig from agentflow.core.realtime.queue import LiveInputQueue from agentflow.prebuilt.agent import AudioAgent -load_dotenv() - -MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-live-2.5-flash-preview") +# Load .env reliably no matter where you launch from: the one next to this script first, +# then the nearest .env walking up from the current working directory. 
+_HERE = os.path.dirname(os.path.abspath(__file__)) +load_dotenv(os.path.join(_HERE, ".env")) +load_dotenv(find_dotenv(usecwd=True)) + +# Use Vertex AI (service account / ADC) when GOOGLE_GENAI_USE_VERTEXAI is set; otherwise +# fall back to a Gemini API key. Both are supported by the live client. +USE_VERTEX = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").strip().lower() in ("1", "true", "yes") +# Live model names differ between the Gemini Developer API and Vertex AI. Override either +# with GEMINI_LIVE_MODEL; check Google's docs for what's enabled in your project/region. +_DEFAULT_MODEL = ( + "gemini-live-2.5-flash-preview-native-audio-09-2025" + if USE_VERTEX + else "gemini-live-2.5-flash-preview" +) +MODEL = os.getenv("GEMINI_LIVE_MODEL", _DEFAULT_MODEL) MIC_BLOCK = INPUT_SAMPLE_RATE // 10 # 100 ms frames +# Full duplex (no mic muting) -- only sensible with headphones, enables barge-in. +FULL_DUPLEX = os.getenv("MIC_FULL_DUPLEX", "").strip().lower() in ("1", "true", "yes") +# How long to keep the mic muted after the agent's turn ends, to let the speaker drain. +MUTE_TAIL_SEC = 0.4 + +# A tiny canned forecast table so different cities give different answers. Swap the body +# for a real HTTP call to a weather API and nothing else here changes. +_FORECASTS = { + "tokyo": "18 degrees Celsius, light rain", + "paris": "24 degrees Celsius and sunny", + "london": "15 degrees Celsius, overcast", + "new york": "21 degrees Celsius, partly cloudy", + "san francisco": "17 degrees Celsius, foggy", +} def get_weather(location: str) -> str: - """Return the current weather for a city. Called by the model during the conversation.""" - return f"It is 22 degrees Celsius and sunny in {location}." + """Get the current weather for a city. Call this whenever the user asks about weather. + + Args: + location: The city name, e.g. "Tokyo" or "Paris". 
+ """ + forecast = _FORECASTS.get(location.strip().lower(), "22 degrees Celsius and clear") + print(f" [tool] get_weather(location={location!r}) -> {forecast}") + return f"The weather in {location} is {forecast}." def build_app(): config = RealtimeConfig( model=MODEL, voice="Puck", - system_instruction="You are a friendly, concise voice assistant.", + system_instruction=( + "You are a friendly, concise voice assistant. When the user asks about the " + "weather, always call the get_weather tool and answer using its result. " + "Keep replies to one or two sentences." + ), ) - return AudioAgent(MODEL, realtime_config=config, tools=[get_weather]).compile() + return AudioAgent( + MODEL, realtime_config=config, tools=[get_weather], use_vertex_ai=USE_VERTEX + ).compile() async def main() -> None: @@ -55,11 +108,21 @@ async def main() -> None: queue = LiveInputQueue() loop = asyncio.get_running_loop() + # Mic gate: while the agent is speaking we drop mic frames so its own voice (played on + # the speaker and picked up by the mic) is never sent back as user input. Disabled in + # full-duplex mode (headphones), where barge-in is wanted instead. + agent_speaking = {"on": False} + def on_mic(indata, _frames, _time, _status) -> None: # PortAudio calls this on its own thread; marshal onto the event loop so the # asyncio-backed queue is touched only from the loop thread. + if agent_speaking["on"] and not FULL_DUPLEX: + return # muted while the agent talks (echo guard) loop.call_soon_threadsafe(queue.send_audio, bytes(indata)) + def unmute() -> None: + agent_speaking["on"] = False + speaker = sd.RawOutputStream(samplerate=OUTPUT_SAMPLE_RATE, channels=1, dtype="int16") mic = sd.RawInputStream( samplerate=INPUT_SAMPLE_RATE, @@ -70,22 +133,30 @@ def on_mic(indata, _frames, _time, _status) -> None: ) speaker.start() mic.start() - print("Listening. 
Speak into your mic; press Ctrl+C to stop.") + mode = "full-duplex (barge-in)" if FULL_DUPLEX else "echo-safe (mic muted while agent talks)" + print(f"Listening [{mode}]. Try: 'What's the weather in Tokyo?' (Ctrl+C to stop)") try: async for event in app.arealtime(queue, {"thread_id": "audio-mic-demo"}): if event.type == "audio_delta": - speaker.write(event.data) + agent_speaking["on"] = True # mute the mic for the duration of the reply + speaker.write(event.data) # play the model's voice + elif event.type == "turn_complete": + # Reopen the mic after the speaker has drained the buffered tail. + loop.call_later(MUTE_TAIL_SEC, unmute) elif event.type == "interrupted": - # Barge-in: discard audio already queued for playback. + # Barge-in (full-duplex only): discard audio already queued for playback. speaker.stop() speaker.start() + agent_speaking["on"] = False elif event.type == "input_transcript" and event.finished: print(f"you: {event.text}") elif event.type == "output_transcript" and event.finished: print(f"agent: {event.text}") elif event.type == "tool_call": - print(f"[tool] {event.name}({event.args})") + print(f" [tool-call requested] {event.name}({event.args})") + elif event.type == "error": + print(f" [error] {event.message}") except (KeyboardInterrupt, asyncio.CancelledError): print("\nStopping...") finally: @@ -98,7 +169,5 @@ def on_mic(indata, _frames, _time, _status) -> None: if __name__ == "__main__": - try: + with contextlib.suppress(KeyboardInterrupt): asyncio.run(main()) - except KeyboardInterrupt: - pass diff --git a/examples/realtime/graph.py b/examples/realtime/graph.py index 36c7c2b1..c7d52f55 100644 --- a/examples/realtime/graph.py +++ b/examples/realtime/graph.py @@ -13,14 +13,17 @@ import os -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv from agentflow.core.realtime.base import RealtimeConfig from agentflow.prebuilt.agent import AudioAgent from agentflow.storage.checkpointer import InMemoryCheckpointer -load_dotenv() 
+# Load .env reliably regardless of the launch directory. +_HERE = os.path.dirname(os.path.abspath(__file__)) +load_dotenv(os.path.join(_HERE, ".env")) +load_dotenv(find_dotenv(usecwd=True)) MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-live-2.5-flash-preview") diff --git a/tests/realtime/test_gemini_live.py b/tests/realtime/test_gemini_live.py index cf9706c0..1d0eb5dc 100644 --- a/tests/realtime/test_gemini_live.py +++ b/tests/realtime/test_gemini_live.py @@ -82,12 +82,41 @@ def test_input_and_output_transcription(self): for e in events ) - def test_interrupted_and_generation_complete(self): - content = _server_content(interrupted=True, generation_complete=True) + def test_interrupted_emitted(self): + content = _server_content(interrupted=True) events = normalize_message(_msg(server_content=content)) assert any(isinstance(e, InterruptedEvent) for e in events) + + def test_turn_complete_emits_turn_complete_event(self): + content = _server_content(turn_complete=True) + events = normalize_message(_msg(server_content=content)) assert any(isinstance(e, TurnCompleteEvent) for e in events) + def test_generation_complete_alone_is_not_a_turn_complete(self): + # generation_complete and turn_complete arrive in separate messages within one turn; + # only turn_complete is the authoritative end-of-turn signal. Mapping both would + # double-count turn boundaries. + content = _server_content(generation_complete=True) + events = normalize_message(_msg(server_content=content)) + assert not any(isinstance(e, TurnCompleteEvent) for e in events) + + def test_finished_transcript_with_no_text_still_emits_finish_marker(self): + # Gemini sends the terminating transcript chunk as finished=True, text=None. + # It must still surface so consumers can flush their accumulated transcript. 
+ content = _server_content( + input_transcription=SimpleNamespace(text=None, finished=True), + output_transcription=SimpleNamespace(text=None, finished=True), + ) + events = normalize_message(_msg(server_content=content)) + assert any( + isinstance(e, InputTranscriptEvent) and e.text == "" and e.finished is True + for e in events + ) + assert any( + isinstance(e, OutputTranscriptEvent) and e.text == "" and e.finished is True + for e in events + ) + def test_tool_call_function_calls(self): fc = SimpleNamespace(id="call-1", name="get_weather", args={"city": "Paris"}) tool_call = SimpleNamespace(function_calls=[fc]) @@ -141,7 +170,10 @@ async def send_client_content(self, **kwargs): self.client_content = kwargs async def receive(self): - for m in self.scripted: + # Mirror the real SDK: receive() drains one turn's messages then completes; a + # subsequent call returns nothing (the client loops receive() across turns). + batch, self.scripted = self.scripted, [] + for m in batch: yield m @@ -211,6 +243,26 @@ async def test_reseed_history_maps_messages_to_send_client_content(self, config) assert [t.role for t in turns] == ["user", "model"] assert session.client_content["turn_complete"] is True + @pytest.mark.asyncio + async def test_reseed_skips_system_and_tool_roles(self, config): + from agentflow.core.state import Message + + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.reseed_history( + [ + Message.text_message("be nice", role="system"), + Message.text_message("hi", role="user"), + Message.text_message("hello", role="assistant"), + ] + ) + + turns = session.client_content["turns"] + # system turn is dropped (set via system_instruction, not reseeded as dialogue) + assert [t.role for t in turns] == ["user", "model"] + @pytest.mark.asyncio async def test_close_exits_context_manager_and_is_idempotent(self, config): session = FakeLiveSession() @@ -225,6 +277,37 @@ async def 
test_close_exits_context_manager_and_is_idempotent(self, config): assert client.connected is False +class TestClientCredentialResolution: + def test_api_key_from_env_is_used(self, monkeypatch): + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + monkeypatch.setenv("GEMINI_API_KEY", "k-123") + client = GeminiLiveClient()._build_client() + assert client is not None # genai.Client built without raising + + def test_explicit_api_key_overrides_env(self, monkeypatch): + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + client = GeminiLiveClient(api_key="explicit")._build_client() + assert client is not None + + def test_missing_credentials_raises_clear_error(self, monkeypatch): + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + with pytest.raises(ValueError, match="GEMINI_API_KEY or GOOGLE_API_KEY"): + GeminiLiveClient()._build_client() + + def test_vertex_without_project_raises(self, monkeypatch): + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + with pytest.raises(ValueError, match="project"): + GeminiLiveClient(use_vertex_ai=True)._build_client() + + def test_vertex_uses_project_and_location(self, monkeypatch): + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "proj-x") + monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "europe-west1") + client = GeminiLiveClient(use_vertex_ai=True)._build_client() + assert client is not None + + class TestBuildConnectConfig: @pytest.mark.asyncio async def test_voice_and_transcription_mapped_into_live_config(self, config): @@ -388,3 +471,40 @@ async def test_receive_before_connect_raises(self, config): with pytest.raises(RuntimeError): async for _ in client.receive(): pass + + @pytest.mark.asyncio + async def test_receive_spans_multiple_turns(self, config): + # Gemini's session.receive() completes per turn; the client must loop it so a + # single receive() call streams events across several turns until the socket 
idles. + class MultiTurnSession: + def __init__(self, batches): + self.batches = list(batches) + + async def send_realtime_input(self, **kw): + pass + + async def receive(self): + if self.batches: + for m in self.batches.pop(0): + yield m + # exhausted -> yields nothing -> client loop stops + + async def __aenter__(self): + return self + + turn1 = [ + _msg(server_content=_server_content(model_turn=SimpleNamespace(parts=[_audio_part(b"\x01")]))), + _msg(server_content=_server_content(turn_complete=True)), + ] + turn2 = [ + _msg(server_content=_server_content(model_turn=SimpleNamespace(parts=[_audio_part(b"\x02")]))), + _msg(server_content=_server_content(turn_complete=True)), + ] + session = MultiTurnSession([turn1, turn2]) + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + kinds = [e.type async for e in client.receive()] + + # both turns streamed from one receive() call, then it stops cleanly + assert kinds == ["audio_delta", "turn_complete", "audio_delta", "turn_complete"] diff --git a/tests/realtime/test_live_agent.py b/tests/realtime/test_live_agent.py index c2f18cdb..842d2250 100644 --- a/tests/realtime/test_live_agent.py +++ b/tests/realtime/test_live_agent.py @@ -255,6 +255,132 @@ async def test_finished_transcripts_persist_as_messages_no_audio(self): assert {m.content[0].text for m in persisted} == {"hello", "hi there"} +class TestTranscriptAccumulation: + @pytest.mark.asyncio + async def test_partial_chunks_accumulate_and_flush_on_finished(self): + from agentflow.storage.checkpointer import InMemoryCheckpointer + + # Streamed as partials (finished=False) then a finish marker with empty text. 
+ client = FakeRealtimeClient( + [ + OutputTranscriptEvent(text="Hello ", finished=False), + OutputTranscriptEvent(text="there ", finished=False), + OutputTranscriptEvent(text="friend.", finished=False), + OutputTranscriptEvent(text="", finished=True), + ] + ) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + cp = InMemoryCheckpointer() + state = AgentState() + config = {"thread_id": "t-acc", "user_id": "u1"} + + await _drain(agent, _closed_queue(), config, state=state, checkpointer=cp) + + persisted = await cp.alist_messages(config) + assert [m.content[0].text for m in persisted] == ["Hello there friend."] + + @pytest.mark.asyncio + async def test_finished_event_carries_full_text_to_consumer(self): + # The consumer should get one consolidated finished transcript with the whole text + # (not the empty finish marker), without having to accumulate partials itself. + client = FakeRealtimeClient( + [ + OutputTranscriptEvent(text="Hello ", finished=False), + OutputTranscriptEvent(text="world", finished=False), + OutputTranscriptEvent(text="", finished=True), + ] + ) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t-consolidate"}) + + finished = [ + e for e in events if e.type == "output_transcript" and e.finished + ] + assert len(finished) == 1 + assert finished[0].text == "Hello world" + + @pytest.mark.asyncio + async def test_interruption_discards_partial_so_restart_is_not_concatenated(self): + # A barge-in mid-transcript must drop the abandoned partial; the restarted turn's + # text must not be glued onto it. 
+ client = FakeRealtimeClient( + [ + OutputTranscriptEvent(text="The weather is ", finished=False), + InterruptedEvent(), + OutputTranscriptEvent(text="It is sunny.", finished=False), + OutputTranscriptEvent(text="", finished=True), + ] + ) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t-barge"}) + + finished = [e for e in events if e.type == "output_transcript" and e.finished] + assert len(finished) == 1 + assert finished[0].text == "It is sunny." + + @pytest.mark.asyncio + async def test_empty_transcript_turn_emits_no_finished_event(self): + # A finished marker with nothing accumulated must not surface an empty transcript. + client = FakeRealtimeClient([OutputTranscriptEvent(text="", finished=True)]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + events = await _drain(agent, _closed_queue(), {"thread_id": "t-empty"}) + + assert not any(e.type == "output_transcript" for e in events) + + +class TestReseedGating: + @pytest.mark.asyncio + async def test_reseed_skipped_when_resumed_from_handle(self): + from agentflow.storage.checkpointer import InMemoryCheckpointer + from agentflow.utils.thread_info import ThreadInfo + + cp = InMemoryCheckpointer() + config = {"thread_id": "t-resumed", "user_id": "u1"} + await cp.aput_messages(config, [Message.text_message("earlier", role="user")]) + # A stored handle means the provider restores context on connect; reseed must NOT + # replay history again. 
+ await cp.aput_thread( + config, ThreadInfo(thread_id="t-resumed", metadata={"resumption_handle": "H1"}) + ) + + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), config, checkpointer=cp) + + assert client.connected_with == ["H1"] + assert client.reseeded is None + + +class TestReconnectBackoff: + @pytest.mark.asyncio + async def test_error_driven_reconnect_gives_up_and_emits_fatal_error(self): + # Every socket drops on receive: error-driven reconnect must back off, cap, then + # surface a fatal ErrorEvent rather than spin forever. + class DroppingClient(FakeRealtimeClient): + async def receive(self): + raise ConnectionError("socket dropped") + yield # pragma: no cover - makes this an async generator + + def make(): + return DroppingClient() + + agent = LiveAgent(MODEL, realtime_client_factory=make) + agent._reconnect_base_delay = 0.0 # no real sleeping in the test + agent._reconnect_max_attempts = 3 + q = LiveInputQueue() # left open so the loop is allowed to attempt resume + + events = await _drain(agent, q, {"thread_id": "t-storm"}) + + fatal = [e for e in events if e.type == "error"] + assert len(fatal) == 1 + assert fatal[0].fatal is True + assert fatal[0].code == "reconnect_failed" + + class TestSessionConfig: @pytest.mark.asyncio async def test_per_session_overrides_merge_over_realtime_config(self): From c14d1c7c7b264b79ecb005d8e8e6fb7535c70bbd Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Tue, 16 Jun 2026 00:26:52 +0600 Subject: [PATCH 3/3] refactor: streamline response modalities initialization and improve type annotations in LiveAgent --- agentflow/core/realtime/base.py | 6 ++++- agentflow/core/realtime/live_agent.py | 35 ++++++++++++--------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/agentflow/core/realtime/base.py b/agentflow/core/realtime/base.py index c531ed1b..8ad099cf 100644 --- 
a/agentflow/core/realtime/base.py +++ b/agentflow/core/realtime/base.py @@ -128,6 +128,10 @@ class ErrorEvent(BaseModel): ResponseModality = Literal["AUDIO", "TEXT"] +def _default_modalities() -> list[ResponseModality]: + return ["AUDIO"] + + class VADConfig(BaseModel): """Voice-activity-detection settings. Disable for push-to-talk (manual activity).""" @@ -151,7 +155,7 @@ class RealtimeConfig(BaseModel): model_config = ConfigDict(validate_default=True) model: str - response_modalities: list[ResponseModality] = Field(default_factory=lambda: ["AUDIO"]) + response_modalities: list[ResponseModality] = Field(default_factory=_default_modalities) voice: str | None = None system_instruction: str | None = None input_audio_transcription: bool = True diff --git a/agentflow/core/realtime/live_agent.py b/agentflow/core/realtime/live_agent.py index 4e59ec0a..0556a09e 100644 --- a/agentflow/core/realtime/live_agent.py +++ b/agentflow/core/realtime/live_agent.py @@ -22,7 +22,7 @@ import contextlib import logging from collections.abc import AsyncIterator, Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from agentflow.core.graph.agent_internal.memory import AgentMemoryMixin from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin @@ -35,6 +35,7 @@ OutputTranscriptEvent, RealtimeClient, RealtimeConfig, + ToolCallEvent, ToolResultEvent, ) from agentflow.core.realtime.providers.gemini_live import GeminiLiveClient @@ -174,9 +175,7 @@ async def arun( self._active_client = client # Only reseed when the provider did NOT restore context from a handle; otherwise the # model would receive the whole conversation twice (handle restore + reseed). 
- await self._maybe_reseed( - config, checkpointer, context_manager, resumed=handle is not None - ) + await self._maybe_reseed(config, checkpointer, context_manager, resumed=handle is not None) # Closing the input queue ends the session: the pump sets this when it drains the # close sentinel, and the receive loop stops once the provider goes idle. @@ -332,27 +331,25 @@ async def _handle_event( checkpointer: BaseCheckpointer | None, callback_manager: CallbackManager, ) -> list[RealtimeEvent]: - kind = event.type - - if kind == "tool_call": + if event.type == "tool_call": return await self._run_tool(event, config, state, callback_manager) - if kind == "input_transcript": + if event.type == "input_transcript": return await self._accumulate_transcript(event, "user", config, state, checkpointer) - if kind == "output_transcript": + if event.type == "output_transcript": return await self._accumulate_transcript( event, "assistant", config, state, checkpointer ) - if kind == "interrupted": + if event.type == "interrupted": # Barge-in discards the model's in-flight turn; drop any partial transcript so # the restarted turn isn't concatenated onto the abandoned one. 
self._input_transcript_buf = "" self._output_transcript_buf = "" self._publish_realtime(EventType.INTERRUPTED, config, ContentType.AUDIO, "barge_in") - elif kind == "go_away": + elif event.type == "go_away": self._publish_realtime(EventType.UPDATE, config, ContentType.UPDATE, "go_away") - elif kind == "session_update": + elif event.type == "session_update": self._resume_handle = event.resumption_handle await self._persist_handle(config, checkpointer) self._publish_realtime(EventType.UPDATE, config, ContentType.UPDATE, "session_resumed") @@ -361,8 +358,8 @@ async def _handle_event( async def _accumulate_transcript( self, - event: RealtimeEvent, - role: str, + event: InputTranscriptEvent | OutputTranscriptEvent, + role: Literal["user", "assistant"], config: dict[str, Any], state: AgentState, checkpointer: BaseCheckpointer | None, @@ -402,7 +399,7 @@ async def _accumulate_transcript( async def _run_tool( self, - event: RealtimeEvent, + event: ToolCallEvent, config: dict[str, Any], state: AgentState, callback_manager: CallbackManager, @@ -447,7 +444,7 @@ def _extract_tool_result(invoked: Any) -> dict[str, Any]: async def _persist_transcript( self, text: str, - role: str, + role: Literal["user", "assistant"], config: dict[str, Any], state: AgentState, checkpointer: BaseCheckpointer | None, @@ -519,7 +516,7 @@ async def _maybe_reseed( history = trimmed.context except Exception: logger.warning("context compression failed during reseed; using raw history") - if history: + if history and self._active_client is not None: await self._active_client.reseed_history(list(history)) async def _attempt_reconnect( @@ -552,9 +549,7 @@ async def _attempt_reconnect( fatal=True, ) - delay = min( - self._reconnect_base_delay * (2 ** (attempts - 1)), self._reconnect_max_delay - ) + delay = min(self._reconnect_base_delay * (2 ** (attempts - 1)), self._reconnect_max_delay) await asyncio.sleep(delay) with contextlib.suppress(Exception): await self._reconnect(rt)