From 82a58373cfab174358c004dd6d17388dc03b7dbd Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Tue, 16 Jun 2026 01:35:31 +0600 Subject: [PATCH 1/2] feat: add image handling and reconnect configuration to realtime components --- agentflow/core/realtime/__init__.py | 2 + agentflow/core/realtime/base.py | 19 ++ agentflow/core/realtime/live_agent.py | 118 +++++++++- .../core/realtime/providers/gemini_live.py | 5 + agentflow/core/realtime/queue.py | 15 +- agentflow/prebuilt/agent/audio.py | 23 +- agentflow/utils/callbacks.py | 70 ++++++ tests/realtime/test_gemini_live.py | 13 ++ tests/realtime/test_live_agent.py | 206 ++++++++++++++++++ tests/realtime/test_package_exports.py | 1 + tests/realtime/test_queue.py | 15 ++ 11 files changed, 469 insertions(+), 18 deletions(-) diff --git a/agentflow/core/realtime/__init__.py b/agentflow/core/realtime/__init__.py index e8bf65b..6227be8 100644 --- a/agentflow/core/realtime/__init__.py +++ b/agentflow/core/realtime/__init__.py @@ -17,6 +17,7 @@ RealtimeClient, RealtimeConfig, RealtimeEvent, + ReconnectConfig, SessionUpdateEvent, ToolCallEvent, ToolResultEvent, @@ -42,6 +43,7 @@ "RealtimeClient", "RealtimeConfig", "RealtimeEvent", + "ReconnectConfig", "SessionUpdateEvent", "ToolCallEvent", "ToolResultEvent", diff --git a/agentflow/core/realtime/base.py b/agentflow/core/realtime/base.py index 8ad099c..93aa56b 100644 --- a/agentflow/core/realtime/base.py +++ b/agentflow/core/realtime/base.py @@ -143,6 +143,20 @@ class VADConfig(BaseModel): silence_duration_ms: int | None = None +class ReconnectConfig(BaseModel): + """Reconnect/backoff policy for a dropped realtime socket. + + Provider-initiated ``go_away`` rotations always reconnect immediately (no backoff). Only + error-driven drops back off: attempt ``n`` waits ``min(base_delay * 2**(n-1), max_delay)`` + seconds, up to ``max_attempts`` tries before the session ends with a fatal error. Set + ``max_attempts=0`` to disable error-driven reconnect entirely. + """ + + base_delay: float = Field(default=0.5, ge=0.0) + max_delay: float = Field(default=10.0, ge=0.0) + max_attempts: int = Field(default=5, ge=0) + + class RealtimeConfig(BaseModel): """Per-session configuration handed to a :class:`RealtimeClient`. @@ -161,6 +175,7 @@ class RealtimeConfig(BaseModel): input_audio_transcription: bool = True output_audio_transcription: bool = True vad: VADConfig = Field(default_factory=VADConfig) + reconnect: ReconnectConfig = Field(default_factory=ReconnectConfig) context_window_compression: bool = False session_resumption: bool = True tools: list[Any] | None = None @@ -200,6 +215,10 @@ async def send_text(self, text: str) -> None: """Send a text turn into the live session.""" ... + async def send_image(self, data: bytes, mime_type: str) -> None: + """Send a single image frame (still image or video frame) into the live session.""" + ... + async def send_activity_start(self) -> None: """Manual-VAD / push-to-talk: mark the start of user activity.""" ... diff --git a/agentflow/core/realtime/live_agent.py b/agentflow/core/realtime/live_agent.py index 0556a09..8d1c9f0 100644 --- a/agentflow/core/realtime/live_agent.py +++ b/agentflow/core/realtime/live_agent.py @@ -43,6 +43,15 @@ from agentflow.runtime.publisher.events import ContentType, Event, EventModel, EventType from agentflow.runtime.publisher.publish import publish_event from agentflow.utils import CallbackManager +from agentflow.utils.callbacks import GraphLifecycleContext + + +# Event kinds that constitute model/user turn content. A turn starts on the first of these +# after a turn boundary and ends on turn_complete/interrupted; control frames (session_update, +# go_away, error) never open a turn. +_TURN_CONTENT_TYPES = frozenset( + {"audio_delta", "input_transcript", "output_transcript", "tool_call", "tool_result"} +) if TYPE_CHECKING: @@ -118,10 +127,12 @@ def __init__( self._output_transcript_buf = "" # Error-driven reconnect backoff (go_away reconnects are immediate; only transient - # drops back off). Instance attributes so tests can shrink them. - self._reconnect_base_delay = 0.5 - self._reconnect_max_delay = 10.0 - self._reconnect_max_attempts = 5 + # drops back off). Seeded from RealtimeConfig.reconnect; kept as instance attributes + # so tests can shrink them without rebuilding the config. + rc = self.realtime_config.reconnect + self._reconnect_base_delay = rc.base_delay + self._reconnect_max_delay = rc.max_delay + self._reconnect_max_attempts = rc.max_attempts # Builder mixins (no-op when their config is None). self._setup_memory(memory) @@ -150,7 +161,7 @@ def _resolve_tool_node(self) -> ToolNode | None: # ------------------------------------------------------------------ # # The duplex realtime loop. # ------------------------------------------------------------------ # - async def arun( + async def arun( # noqa: PLR0912, PLR0915 self, input_queue: LiveInputQueue, config: dict[str, Any], @@ -168,6 +179,7 @@ async def arun( self._output_transcript_buf = "" rt = self._session_realtime_config(config) rt = await self._resolve_session_tools(rt) + rt = await self._resolve_session_system_instruction(rt, state, config) handle = await self._load_resume_handle(config, checkpointer) client = self._client_factory() @@ -177,11 +189,17 @@ async def arun( # model would receive the whole conversation twice (handle restore + reseed). await self._maybe_reseed(config, checkpointer, context_manager, resumed=handle is not None) + # Session start mirrors a graph run: the LIVE node *is* the graph, so on_graph_start + # fires once here (before any turn) and on_graph_end once when the session ends. + state = await self._fire_graph_start(callback_manager, config, state) + # Closing the input queue ends the session: the pump sets this when it drains the # close sentinel, and the receive loop stops once the provider goes idle. stop_event = asyncio.Event() pump_task = asyncio.create_task(self._pump(input_queue, stop_event)) attempts = 0 # consecutive error-driven reconnect attempts (reset on healthy receive) + turn_index = 0 # 1-based count of turns started; doubles as on_graph_end total_steps + turn_active = False # a turn is open (content seen, no turn_complete/interrupt yet) try: while True: reconnect = False @@ -190,16 +208,29 @@ async def arun( try: async for event in self._receive_until_stop(self._active_client, stop_event): received_any = True + if not turn_active and event.type in _TURN_CONTENT_TYPES: + turn_index += 1 + turn_active = True + state = await self._fire_turn_start( + callback_manager, config, state, turn_index + ) for out in await self._handle_event( event, config, state, checkpointer, callback_manager ): yield out + if turn_active and event.type in ("turn_complete", "interrupted"): + state = await self._fire_turn_end( + callback_manager, config, state, turn_index + ) + turn_active = False if event.type == "go_away": reconnect = True forced = True break if event.type == "error" and getattr(event, "fatal", False): - return + # break (not return) so on_graph_end still fires for the session. + reconnect = False + break except Exception: # Transient drop: only resume if input is still open (avoid an # infinite reconnect storm once the session is shutting down). @@ -218,6 +249,11 @@ async def arun( if fatal is not None: yield fatal break + + # Balance a turn cut off by session end (no turn_complete arrived), then close out. + if turn_active: + state = await self._fire_turn_end(callback_manager, config, state, turn_index) + await self._fire_graph_end(callback_manager, config, state, turn_index) finally: pump_task.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -258,6 +294,42 @@ async def _resolve_session_tools(self, rt: RealtimeConfig) -> RealtimeConfig: return rt return rt.model_copy(update={"tools": schemas}) + async def _resolve_session_system_instruction( + self, rt: RealtimeConfig, state: AgentState, config: dict[str, Any] + ) -> RealtimeConfig: + """Flatten the agent's system prompt (+ skills + memory) into ``system_instruction``. + + Gemini Live takes a single ``system_instruction`` string fixed at connect time, so + the per-turn prompt list other agents send must be collapsed once, here. This is what + makes ``system_prompt``, the skills trigger table / session-mode content, and the + memory system prompt actually reach the model in realtime (the matching tools are + advertised separately by :meth:`_resolve_session_tools`). + + State-dependent pieces (session-mode skill from a state field, memory preload from the + latest user query) are therefore a connect-time snapshot, not re-evaluated per turn; + dynamic behaviour mid-session goes through ``set_skill`` / memory tools instead. + + ``{field}`` placeholders in the prompt content are interpolated from ``state`` exactly + like the turn-based path (via :func:`convert_messages`), so a system prompt that reads + from state behaves identically here. + """ + from agentflow.utils.converter import _interpolate_system_prompts + + base = list(self.system_prompt) + if not base and rt.system_instruction: + base = [{"role": "system", "content": rt.system_instruction}] + + prompts = self._build_skill_prompts(state, base) + prompts = prompts + await self._build_memory_prompts(state, config) + prompts = _interpolate_system_prompts(prompts, state) + + instruction = "\n\n".join( + str(p["content"]) for p in prompts if p.get("content") + ).strip() + if not instruction: + return rt + return rt.model_copy(update={"system_instruction": instruction}) + async def _receive_until_stop( self, client: RealtimeClient, stop_event: asyncio.Event ) -> AsyncIterator[RealtimeEvent]: @@ -307,6 +379,8 @@ async def _pump( await client.send_audio(item.data, item.sample_rate) elif item.kind == "text" and item.text is not None: await client.send_text(item.text) + elif item.kind == "image" and item.data is not None: + await client.send_image(item.data, item.mime_type or "image/jpeg") elif item.kind == "activity_start": await client.send_activity_start() elif item.kind == "activity_end": @@ -565,6 +639,38 @@ async def _reconnect(self, rt: RealtimeConfig) -> None: await client.connect(rt, resume_handle=self._resume_handle) self._active_client = client + # ------------------------------------------------------------------ # + # Lifecycle hooks (session == graph run; turn == one model generation). + # ------------------------------------------------------------------ # + async def _fire_graph_start( + self, cb: CallbackManager, config: dict[str, Any], state: AgentState + ) -> AgentState: + if not cb._lifecycle_hooks: + return state + return await cb.fire_on_graph_start(GraphLifecycleContext(config=config), state) + + async def _fire_graph_end( + self, cb: CallbackManager, config: dict[str, Any], state: AgentState, turns: int + ) -> None: + if not cb._lifecycle_hooks: + return + messages = list(getattr(state, "context", []) or []) + await cb.fire_on_graph_end(GraphLifecycleContext(config=config), state, messages, turns) + + async def _fire_turn_start( + self, cb: CallbackManager, config: dict[str, Any], state: AgentState, turn_index: int + ) -> AgentState: + if not cb._lifecycle_hooks: + return state + return await cb.fire_on_turn_start(GraphLifecycleContext(config=config), state, turn_index) + + async def _fire_turn_end( + self, cb: CallbackManager, config: dict[str, Any], state: AgentState, turn_index: int + ) -> AgentState: + if not cb._lifecycle_hooks: + return state + return await cb.fire_on_turn_end(GraphLifecycleContext(config=config), state, turn_index) + # ------------------------------------------------------------------ # # Observability for events ToolNode doesn't already publish. # ------------------------------------------------------------------ # diff --git a/agentflow/core/realtime/providers/gemini_live.py b/agentflow/core/realtime/providers/gemini_live.py index 93ffdcb..0fce1b9 100644 --- a/agentflow/core/realtime/providers/gemini_live.py +++ b/agentflow/core/realtime/providers/gemini_live.py @@ -301,6 +301,11 @@ async def send_text(self, text: str) -> None: session = self._require_session() await session.send_realtime_input(text=text) + async def send_image(self, data: bytes, mime_type: str = "image/jpeg") -> None: + session = self._require_session() + _, types = self._genai() + await session.send_realtime_input(media=types.Blob(data=data, mime_type=mime_type)) + async def send_activity_start(self) -> None: session = self._require_session() _, types = self._genai() diff --git a/agentflow/core/realtime/queue.py b/agentflow/core/realtime/queue.py index efd0fff..5f10980 100644 --- a/agentflow/core/realtime/queue.py +++ b/agentflow/core/realtime/queue.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -LiveInputKind = Literal["audio", "text", "activity_start", "activity_end", "close"] +LiveInputKind = Literal["audio", "text", "image", "activity_start", "activity_end", "close"] @dataclass(slots=True) @@ -31,13 +31,15 @@ class LiveInput: """A single upstream transport frame. Construct via ``LiveInputQueue.send_*``. ``kind`` discriminates the frame; only the fields relevant to that kind are set - (``data``/``sample_rate`` for audio, ``text`` for text, neither for control frames). + (``data``/``sample_rate`` for audio, ``data``/``mime_type`` for image, ``text`` for + text, none for control frames). """ kind: LiveInputKind data: bytes | None = None text: str | None = None sample_rate: int = INPUT_SAMPLE_RATE + mime_type: str | None = None class LiveInputQueue: @@ -71,6 +73,15 @@ def send_audio(self, data: bytes, sample_rate: int = INPUT_SAMPLE_RATE) -> None: def send_text(self, text: str) -> None: self._put(LiveInput(kind="text", text=text)) + def send_image(self, data: bytes, mime_type: str = "image/jpeg") -> None: + """Send a single image frame (e.g. a JPEG camera frame) into the live session. + + Gemini Live accepts still images and video as individual frames; send video as a + stream of frames (~1 fps is the model's effective ceiling). ``mime_type`` must be an + image type the provider supports (default ``image/jpeg``). + """ + self._put(LiveInput(kind="image", data=data, mime_type=mime_type)) + def send_activity_start(self) -> None: self._put(LiveInput(kind="activity_start")) diff --git a/agentflow/prebuilt/agent/audio.py b/agentflow/prebuilt/agent/audio.py index 5ea30d1..503c0fb 100644 --- a/agentflow/prebuilt/agent/audio.py +++ b/agentflow/prebuilt/agent/audio.py @@ -15,12 +15,13 @@ from agentflow.core.graph.tool_node import ToolNode from agentflow.core.realtime.base import RealtimeClient, RealtimeConfig from agentflow.core.realtime.live_agent import LiveAgent +from agentflow.core.skills.models import SkillConfig from agentflow.core.state.agent_state import AgentState from agentflow.core.state.base_context import BaseContextManager from agentflow.runtime.publisher.base_publisher import BasePublisher from agentflow.storage.checkpointer.base_checkpointer import BaseCheckpointer -from agentflow.storage.media.storage.base import BaseMediaStore from agentflow.storage.store.base_store import BaseStore +from agentflow.storage.store.memory_config import MemoryConfig from agentflow.utils.callbacks import CallbackManager from agentflow.utils.constants import END from agentflow.utils.id_generator import BaseIDGenerator, DefaultIDGenerator @@ -43,8 +44,8 @@ def __init__( # noqa: PLR0913 tools: Iterable[Callable] | None = None, client: Any = None, pass_user_info_to_mcp: bool = False, - skills: Any | None = None, - memory: Any | None = None, + skills: SkillConfig | None = None, + memory: MemoryConfig | None = None, realtime_client_factory: Callable[[], RealtimeClient] | None = None, live_node_name: str = "LIVE", **agent_kwargs: Any, @@ -76,7 +77,10 @@ def __init__( # noqa: PLR0913 @staticmethod def _build_tool_node( - *, tools: list[Callable], client: Any, pass_user_info_to_mcp: bool + *, + tools: list[Callable], + client: Any, + pass_user_info_to_mcp: bool, ) -> ToolNode | None: if not tools and client is None: return None @@ -103,23 +107,22 @@ def compile( self, checkpointer: BaseCheckpointer[StateT] | None = None, store: BaseStore | None = None, - interrupt_before: list[str] | None = None, - interrupt_after: list[str] | None = None, callback_manager: CallbackManager | None = None, - media_store: BaseMediaStore | None = None, shutdown_timeout: float = 30.0, ) -> CompiledGraph: + # No media_store: realtime media (images/video) is sent frame-by-frame straight to + # the live model via the input queue (see LiveInputQueue.send_image); it is never + # offloaded to or resolved from a media store, so the parameter would be dead here. self._configure_graph() + if self._graph is None: # pragma: no cover - _configure_graph always assigns raise RuntimeError("graph configuration failed") + return self._graph.compile( checkpointer=checkpointer, store=store, - interrupt_before=interrupt_before, - interrupt_after=interrupt_after, callback_manager=callback_manager if callback_manager is not None else CallbackManager(), - media_store=media_store, shutdown_timeout=shutdown_timeout, ) diff --git a/agentflow/utils/callbacks.py b/agentflow/utils/callbacks.py index 8e5b4cb..4aff46e 100644 --- a/agentflow/utils/callbacks.py +++ b/agentflow/utils/callbacks.py @@ -331,6 +331,34 @@ async def on_state_update( """ return None + async def on_turn_start( + self, + context: "GraphLifecycleContext", + state: StateT, + turn_index: int, + ) -> "StateT | None": + """Called when a new realtime conversation turn begins (realtime sessions only). + + A "turn" spans one model generation, bounded by the provider's turn-complete (or a + barge-in). ``turn_index`` is 1-based. Return a modified StateT to replace the current + state, or None to keep it. Never fired by turn-based invoke/stream runs. + """ + return None + + async def on_turn_end( + self, + context: "GraphLifecycleContext", + state: StateT, + turn_index: int, + ) -> "StateT | None": + """Called when a realtime conversation turn completes or is interrupted. + + ``turn_index`` is the 1-based index of the turn that just ended. Return a modified + StateT to replace the current state, or None to keep it. Never fired by turn-based + invoke/stream runs. + """ + return None + class CallbackManager: """ @@ -946,3 +974,45 @@ async def fire_on_state_update( ) ) return result + + async def fire_on_turn_start( + self, + context: GraphLifecycleContext, + state: StateT, + turn_index: int, + ) -> StateT: + """Fire all on_turn_start hooks and return the (potentially modified) state.""" + result = state + for hook in self._lifecycle_hooks: + try: + modified = await hook.on_turn_start(context, result, turn_index) + if modified is not None: + result = modified + except Exception as e: + logger.exception( + "Lifecycle hook %s.on_turn_start failed: %s", + hook.__class__.__name__, + e, + ) + return result + + async def fire_on_turn_end( + self, + context: GraphLifecycleContext, + state: StateT, + turn_index: int, + ) -> StateT: + """Fire all on_turn_end hooks and return the (potentially modified) state.""" + result = state + for hook in self._lifecycle_hooks: + try: + modified = await hook.on_turn_end(context, result, turn_index) + if modified is not None: + result = modified + except Exception as e: + logger.exception( + "Lifecycle hook %s.on_turn_end failed: %s", + hook.__class__.__name__, + e, + ) + return result diff --git a/tests/realtime/test_gemini_live.py b/tests/realtime/test_gemini_live.py index 1d0eb5d..c88467e 100644 --- a/tests/realtime/test_gemini_live.py +++ b/tests/realtime/test_gemini_live.py @@ -438,6 +438,19 @@ async def test_send_text_maps_to_realtime_input(self, config): assert session.sent_realtime[0]["text"] == "hello there" + @pytest.mark.asyncio + async def test_send_image_maps_to_media_blob(self, config): + session = FakeLiveSession() + client = GeminiLiveClient(connector=FakeConnector(session)) + await client.connect(config) + + await client.send_image(b"\xff\xd8\xff", mime_type="image/jpeg") + + assert len(session.sent_realtime) == 1 + blob = session.sent_realtime[0]["media"] + assert blob.data == b"\xff\xd8\xff" + assert blob.mime_type == "image/jpeg" + @pytest.mark.asyncio async def test_send_before_connect_raises(self, config): client = GeminiLiveClient(connector=FakeConnector(FakeLiveSession())) diff --git a/tests/realtime/test_live_agent.py b/tests/realtime/test_live_agent.py index 842d225..65b7a6b 100644 --- a/tests/realtime/test_live_agent.py +++ b/tests/realtime/test_live_agent.py @@ -29,6 +29,7 @@ from agentflow.runtime.publisher.base_publisher import BasePublisher from agentflow.utils import CallbackManager from agentflow.utils.background_task_manager import BackgroundTaskManager +from agentflow.utils.callbacks import GraphLifecycleHook MODEL = "gemini-2.5-flash-live" @@ -40,6 +41,7 @@ def __init__(self, events=None): self.connected_config = None self.sent_audio: list[tuple[bytes, int]] = [] self.sent_text: list[str] = [] + self.sent_images: list[tuple[bytes, str]] = [] self.activity: list[str] = [] self.tool_responses: list[tuple[str, str, object]] = [] self.reseeded = None @@ -55,6 +57,9 @@ async def send_audio(self, pcm, sample_rate): async def send_text(self, text): self.sent_text.append(text) + async def send_image(self, data, mime_type): + self.sent_images.append((data, mime_type)) + async def send_activity_start(self): self.activity.append("start") @@ -126,6 +131,7 @@ async def test_pump_maps_queue_frames_to_provider(self): q = LiveInputQueue() q.send_audio(b"x", sample_rate=16000) q.send_text("hi") + q.send_image(b"\xff\xd8\xff", mime_type="image/jpeg") q.send_activity_start() q.send_activity_end() q.close() @@ -134,6 +140,7 @@ async def test_pump_maps_queue_frames_to_provider(self): assert client.sent_audio == [(b"x", 16000)] assert client.sent_text == ["hi"] + assert client.sent_images == [(b"\xff\xd8\xff", "image/jpeg")] assert client.activity == ["start", "end"] @@ -380,6 +387,41 @@ def make(): assert fatal[0].fatal is True assert fatal[0].code == "reconnect_failed" + def test_reconnect_settings_seeded_from_realtime_config(self): + from agentflow.core.realtime.base import ReconnectConfig + + cfg = RealtimeConfig( + model=MODEL, + reconnect=ReconnectConfig(base_delay=0.1, max_delay=2.0, max_attempts=2), + ) + agent = LiveAgent(MODEL, realtime_config=cfg) + + assert agent._reconnect_base_delay == 0.1 + assert agent._reconnect_max_delay == 2.0 + assert agent._reconnect_max_attempts == 2 + + @pytest.mark.asyncio + async def test_max_attempts_zero_disables_error_driven_reconnect(self): + from agentflow.core.realtime.base import ReconnectConfig + + class DroppingClient(FakeRealtimeClient): + attempts = 0 + + async def receive(self): + DroppingClient.attempts += 1 + raise ConnectionError("socket dropped") + yield # pragma: no cover - makes this an async generator + + cfg = RealtimeConfig(model=MODEL, reconnect=ReconnectConfig(max_attempts=0)) + agent = LiveAgent(MODEL, realtime_config=cfg, realtime_client_factory=DroppingClient) + q = LiveInputQueue() # left open: reconnect would be allowed if not disabled + + events = await _drain(agent, q, {"thread_id": "t-no-retry"}) + + # The first drop is fatal immediately; the socket is never reopened. + assert DroppingClient.attempts == 1 + assert [e for e in events if e.type == "error"][0].code == "reconnect_failed" + class TestSessionConfig: @pytest.mark.asyncio @@ -408,6 +450,170 @@ async def test_no_overrides_uses_base_config_identity(self): assert client.connected_config is agent.realtime_config +class TestSystemInstruction: + @pytest.mark.asyncio + async def test_system_prompt_flattened_into_system_instruction(self): + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent( + MODEL, + system_prompt=[ + {"role": "system", "content": "You are a pirate."}, + {"role": "system", "content": "Always answer in one sentence."}, + ], + realtime_client_factory=_factory(client), + ) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.connected_config.system_instruction == ( + "You are a pirate.\n\nAlways answer in one sentence." + ) + + @pytest.mark.asyncio + async def test_explicit_system_instruction_preserved_when_no_system_prompt(self): + client = FakeRealtimeClient([TurnCompleteEvent()]) + cfg = RealtimeConfig(model=MODEL, system_instruction="Be terse.") + agent = LiveAgent(MODEL, realtime_config=cfg, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.connected_config.system_instruction == "Be terse." + + @pytest.mark.asyncio + async def test_no_system_prompt_leaves_instruction_unset(self): + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}) + + assert client.connected_config.system_instruction is None + + @pytest.mark.asyncio + async def test_system_prompt_interpolates_state_fields(self): + class _State(AgentState): + user_name: str = "" + + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent( + MODEL, + system_prompt=[{"role": "system", "content": "You assist {user_name}."}], + realtime_client_factory=_factory(client), + ) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}, state=_State(user_name="Ada")) + + assert client.connected_config.system_instruction == "You assist Ada." + + @pytest.mark.asyncio + async def test_session_skill_content_reaches_system_instruction(self, tmp_path): + from agentflow.core.skills.models import SkillConfig + + skill_dir = tmp_path / "weather" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: weather\ndescription: weather skill\n---\nCheck the forecast first.", + encoding="utf-8", + ) + + class _State(AgentState): + active_skill: str = "" + + client = FakeRealtimeClient([TurnCompleteEvent()]) + agent = LiveAgent( + MODEL, + system_prompt=[{"role": "system", "content": "Base prompt."}], + skills=SkillConfig(mode="session", preload_from="active_skill", skills_dir=str(tmp_path)), + realtime_client_factory=_factory(client), + ) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}, state=_State(active_skill="weather")) + + # Both the base prompt and the preloaded skill body reach the single instruction. + assert "Base prompt." in client.connected_config.system_instruction + assert "Check the forecast first." in client.connected_config.system_instruction + + +class _RecordingHook(GraphLifecycleHook): + def __init__(self): + self.calls = [] + + async def on_graph_start(self, ctx, state): + self.calls.append(("graph_start", None)) + + async def on_graph_end(self, ctx, final_state, messages, total_steps): + self.calls.append(("graph_end", total_steps)) + + async def on_turn_start(self, ctx, state, turn_index): + self.calls.append(("turn_start", turn_index)) + + async def on_turn_end(self, ctx, state, turn_index): + self.calls.append(("turn_end", turn_index)) + + +class TestLifecycleHooks: + @pytest.mark.asyncio + async def test_session_and_turn_hooks_fire_in_order(self): + cm = CallbackManager() + hook = _RecordingHook() + cm.register_lifecycle_hook(hook) + + # Two model turns, each: content then turn_complete. + client = FakeRealtimeClient( + [ + AudioDeltaEvent(data=b"\x01"), + TurnCompleteEvent(), + AudioDeltaEvent(data=b"\x02"), + TurnCompleteEvent(), + ] + ) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}, callback_manager=cm) + + assert hook.calls == [ + ("graph_start", None), + ("turn_start", 1), + ("turn_end", 1), + ("turn_start", 2), + ("turn_end", 2), + ("graph_end", 2), # total_steps == number of turns + ] + + @pytest.mark.asyncio + async def test_turn_cut_off_by_session_end_still_balances(self): + # Content arrives but no turn_complete before the session ends; on_turn_end must still + # fire so every on_turn_start is balanced, then on_graph_end closes the session. + cm = CallbackManager() + hook = _RecordingHook() + cm.register_lifecycle_hook(hook) + + client = FakeRealtimeClient([AudioDeltaEvent(data=b"\x01")]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}, callback_manager=cm) + + assert hook.calls == [ + ("graph_start", None), + ("turn_start", 1), + ("turn_end", 1), + ("graph_end", 1), + ] + + @pytest.mark.asyncio + async def test_control_only_session_fires_no_turn_hooks(self): + # A session with only control frames (no content) opens no turn. + cm = CallbackManager() + hook = _RecordingHook() + cm.register_lifecycle_hook(hook) + + client = FakeRealtimeClient([SessionUpdateEvent(resumption_handle="h1")]) + agent = LiveAgent(MODEL, realtime_client_factory=_factory(client)) + + await _drain(agent, _closed_queue(), {"thread_id": "t1"}, callback_manager=cm) + + assert hook.calls == [("graph_start", None), ("graph_end", 0)] + + class TestToolAdvertising: @pytest.mark.asyncio async def test_tool_node_tools_advertised_to_provider(self): diff --git a/tests/realtime/test_package_exports.py b/tests/realtime/test_package_exports.py index 8e4b73b..a0afed9 100644 --- a/tests/realtime/test_package_exports.py +++ b/tests/realtime/test_package_exports.py @@ -9,6 +9,7 @@ def test_public_symbols_are_exported(): "RealtimeConfig", "RealtimeEvent", "VADConfig", + "ReconnectConfig", "AudioDeltaEvent", "InputTranscriptEvent", "OutputTranscriptEvent", diff --git a/tests/realtime/test_queue.py b/tests/realtime/test_queue.py index 3ca238b..3b2675a 100644 --- a/tests/realtime/test_queue.py +++ b/tests/realtime/test_queue.py @@ -24,6 +24,21 @@ def test_send_text_enqueues_text_frame(self): assert item.kind == "text" assert item.text == "hello" + def test_send_image_enqueues_image_frame_with_mime(self): + q = LiveInputQueue() + q.send_image(b"\xff\xd8\xff") + item = q.get_nowait() + assert item.kind == "image" + assert item.data == b"\xff\xd8\xff" + assert item.mime_type == "image/jpeg" + + def test_send_image_accepts_custom_mime(self): + q = LiveInputQueue() + q.send_image(b"\x89PNG", mime_type="image/png") + item = q.get_nowait() + assert item.kind == "image" + assert item.mime_type == "image/png" + def test_activity_markers_enqueue_control_frames(self): q = LiveInputQueue() q.send_activity_start() From e84c9c3584685d6445595d4e95135d0da8ba640a Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Tue, 16 Jun 2026 01:40:31 +0600 Subject: [PATCH 2/2] refactor: simplify instruction generation in LiveAgent --- agentflow/core/realtime/live_agent.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/agentflow/core/realtime/live_agent.py b/agentflow/core/realtime/live_agent.py index 8d1c9f0..aed2ef1 100644 --- a/agentflow/core/realtime/live_agent.py +++ b/agentflow/core/realtime/live_agent.py @@ -323,9 +323,7 @@ async def _resolve_session_system_instruction( prompts = prompts + await self._build_memory_prompts(state, config) prompts = _interpolate_system_prompts(prompts, state) - instruction = "\n\n".join( - str(p["content"]) for p in prompts if p.get("content") - ).strip() + instruction = "\n\n".join(str(p["content"]) for p in prompts if p.get("content")).strip() if not instruction: return rt return rt.model_copy(update={"system_instruction": instruction})