Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agentflow/core/realtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RealtimeClient,
RealtimeConfig,
RealtimeEvent,
ReconnectConfig,
SessionUpdateEvent,
ToolCallEvent,
ToolResultEvent,
Expand All @@ -42,6 +43,7 @@
"RealtimeClient",
"RealtimeConfig",
"RealtimeEvent",
"ReconnectConfig",
"SessionUpdateEvent",
"ToolCallEvent",
"ToolResultEvent",
Expand Down
19 changes: 19 additions & 0 deletions agentflow/core/realtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,20 @@ class VADConfig(BaseModel):
silence_duration_ms: int | None = None


class ReconnectConfig(BaseModel):
"""Reconnect/backoff policy for a dropped realtime socket.

Provider-initiated ``go_away`` rotations always reconnect immediately (no backoff). Only
error-driven drops back off: attempt ``n`` waits ``min(base_delay * 2**(n-1), max_delay)``
seconds, up to ``max_attempts`` tries before the session ends with a fatal error. Set
``max_attempts=0`` to disable error-driven reconnect entirely.
"""

base_delay: float = Field(default=0.5, ge=0.0)
max_delay: float = Field(default=10.0, ge=0.0)
max_attempts: int = Field(default=5, ge=0)


class RealtimeConfig(BaseModel):
"""Per-session configuration handed to a :class:`RealtimeClient`.

Expand All @@ -161,6 +175,7 @@ class RealtimeConfig(BaseModel):
input_audio_transcription: bool = True
output_audio_transcription: bool = True
vad: VADConfig = Field(default_factory=VADConfig)
reconnect: ReconnectConfig = Field(default_factory=ReconnectConfig)
context_window_compression: bool = False
session_resumption: bool = True
tools: list[Any] | None = None
Expand Down Expand Up @@ -200,6 +215,10 @@ async def send_text(self, text: str) -> None:
"""Send a text turn into the live session."""
...

async def send_image(self, data: bytes, mime_type: str) -> None:
"""Send a single image frame (still image or video frame) into the live session."""
...
Comment thread
Iamsdt marked this conversation as resolved.
Dismissed

async def send_activity_start(self) -> None:
"""Manual-VAD / push-to-talk: mark the start of user activity."""
...
Expand Down
116 changes: 110 additions & 6 deletions agentflow/core/realtime/live_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@
from agentflow.runtime.publisher.events import ContentType, Event, EventModel, EventType
from agentflow.runtime.publisher.publish import publish_event
from agentflow.utils import CallbackManager
from agentflow.utils.callbacks import GraphLifecycleContext


# Event kinds that constitute model/user turn content. A turn starts on the first of these
# after a turn boundary and ends on turn_complete/interrupted; control frames (session_update,
# go_away, error) never open a turn.
_TURN_CONTENT_TYPES = frozenset(
{"audio_delta", "input_transcript", "output_transcript", "tool_call", "tool_result"}
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -118,10 +127,12 @@ def __init__(
self._output_transcript_buf = ""

# Error-driven reconnect backoff (go_away reconnects are immediate; only transient
# drops back off). Instance attributes so tests can shrink them.
self._reconnect_base_delay = 0.5
self._reconnect_max_delay = 10.0
self._reconnect_max_attempts = 5
# drops back off). Seeded from RealtimeConfig.reconnect; kept as instance attributes
# so tests can shrink them without rebuilding the config.
rc = self.realtime_config.reconnect
self._reconnect_base_delay = rc.base_delay
self._reconnect_max_delay = rc.max_delay
self._reconnect_max_attempts = rc.max_attempts

# Builder mixins (no-op when their config is None).
self._setup_memory(memory)
Expand Down Expand Up @@ -150,7 +161,7 @@ def _resolve_tool_node(self) -> ToolNode | None:
# ------------------------------------------------------------------ #
# The duplex realtime loop.
# ------------------------------------------------------------------ #
async def arun(
async def arun( # noqa: PLR0912, PLR0915
self,
input_queue: LiveInputQueue,
config: dict[str, Any],
Expand All @@ -168,6 +179,7 @@ async def arun(
self._output_transcript_buf = ""
rt = self._session_realtime_config(config)
rt = await self._resolve_session_tools(rt)
rt = await self._resolve_session_system_instruction(rt, state, config)

handle = await self._load_resume_handle(config, checkpointer)
client = self._client_factory()
Expand All @@ -177,11 +189,17 @@ async def arun(
# model would receive the whole conversation twice (handle restore + reseed).
await self._maybe_reseed(config, checkpointer, context_manager, resumed=handle is not None)

# Session start mirrors a graph run: the LIVE node *is* the graph, so on_graph_start
# fires once here (before any turn) and on_graph_end once when the session ends.
state = await self._fire_graph_start(callback_manager, config, state)

# Closing the input queue ends the session: the pump sets this when it drains the
# close sentinel, and the receive loop stops once the provider goes idle.
stop_event = asyncio.Event()
pump_task = asyncio.create_task(self._pump(input_queue, stop_event))
attempts = 0 # consecutive error-driven reconnect attempts (reset on healthy receive)
turn_index = 0 # 1-based count of turns started; doubles as on_graph_end total_steps
turn_active = False # a turn is open (content seen, no turn_complete/interrupt yet)
try:
while True:
reconnect = False
Expand All @@ -190,16 +208,29 @@ async def arun(
try:
async for event in self._receive_until_stop(self._active_client, stop_event):
received_any = True
if not turn_active and event.type in _TURN_CONTENT_TYPES:
turn_index += 1
turn_active = True
state = await self._fire_turn_start(
callback_manager, config, state, turn_index
)
for out in await self._handle_event(
event, config, state, checkpointer, callback_manager
):
yield out
if turn_active and event.type in ("turn_complete", "interrupted"):
state = await self._fire_turn_end(
callback_manager, config, state, turn_index
)
turn_active = False
if event.type == "go_away":
reconnect = True
forced = True
break
if event.type == "error" and getattr(event, "fatal", False):
return
# break (not return) so on_graph_end still fires for the session.
reconnect = False
break
except Exception:
# Transient drop: only resume if input is still open (avoid an
# infinite reconnect storm once the session is shutting down).
Expand All @@ -218,6 +249,11 @@ async def arun(
if fatal is not None:
yield fatal
break

# Balance a turn cut off by session end (no turn_complete arrived), then close out.
if turn_active:
state = await self._fire_turn_end(callback_manager, config, state, turn_index)
await self._fire_graph_end(callback_manager, config, state, turn_index)
finally:
pump_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
Expand Down Expand Up @@ -258,6 +294,40 @@ async def _resolve_session_tools(self, rt: RealtimeConfig) -> RealtimeConfig:
return rt
return rt.model_copy(update={"tools": schemas})

async def _resolve_session_system_instruction(
self, rt: RealtimeConfig, state: AgentState, config: dict[str, Any]
) -> RealtimeConfig:
"""Flatten the agent's system prompt (+ skills + memory) into ``system_instruction``.

Gemini Live takes a single ``system_instruction`` string fixed at connect time, so
the per-turn prompt list other agents send must be collapsed once, here. This is what
makes ``system_prompt``, the skills trigger table / session-mode content, and the
memory system prompt actually reach the model in realtime (the matching tools are
advertised separately by :meth:`_resolve_session_tools`).

State-dependent pieces (session-mode skill from a state field, memory preload from the
latest user query) are therefore a connect-time snapshot, not re-evaluated per turn;
dynamic behaviour mid-session goes through ``set_skill`` / memory tools instead.

``{field}`` placeholders in the prompt content are interpolated from ``state`` exactly
like the turn-based path (via :func:`convert_messages`), so a system prompt that reads
from state behaves identically here.
"""
from agentflow.utils.converter import _interpolate_system_prompts

base = list(self.system_prompt)
if not base and rt.system_instruction:
base = [{"role": "system", "content": rt.system_instruction}]

prompts = self._build_skill_prompts(state, base)
prompts = prompts + await self._build_memory_prompts(state, config)
prompts = _interpolate_system_prompts(prompts, state)

instruction = "\n\n".join(str(p["content"]) for p in prompts if p.get("content")).strip()
if not instruction:
return rt
return rt.model_copy(update={"system_instruction": instruction})

async def _receive_until_stop(
self, client: RealtimeClient, stop_event: asyncio.Event
) -> AsyncIterator[RealtimeEvent]:
Expand Down Expand Up @@ -307,6 +377,8 @@ async def _pump(
await client.send_audio(item.data, item.sample_rate)
elif item.kind == "text" and item.text is not None:
await client.send_text(item.text)
elif item.kind == "image" and item.data is not None:
await client.send_image(item.data, item.mime_type or "image/jpeg")
elif item.kind == "activity_start":
await client.send_activity_start()
elif item.kind == "activity_end":
Expand Down Expand Up @@ -565,6 +637,38 @@ async def _reconnect(self, rt: RealtimeConfig) -> None:
await client.connect(rt, resume_handle=self._resume_handle)
self._active_client = client

# ------------------------------------------------------------------ #
# Lifecycle hooks (session == graph run; turn == one model generation).
# ------------------------------------------------------------------ #
async def _fire_graph_start(
self, cb: CallbackManager, config: dict[str, Any], state: AgentState
) -> AgentState:
if not cb._lifecycle_hooks:
return state
return await cb.fire_on_graph_start(GraphLifecycleContext(config=config), state)

async def _fire_graph_end(
self, cb: CallbackManager, config: dict[str, Any], state: AgentState, turns: int
) -> None:
if not cb._lifecycle_hooks:
return
messages = list(getattr(state, "context", []) or [])
await cb.fire_on_graph_end(GraphLifecycleContext(config=config), state, messages, turns)

async def _fire_turn_start(
self, cb: CallbackManager, config: dict[str, Any], state: AgentState, turn_index: int
) -> AgentState:
if not cb._lifecycle_hooks:
return state
return await cb.fire_on_turn_start(GraphLifecycleContext(config=config), state, turn_index)

async def _fire_turn_end(
self, cb: CallbackManager, config: dict[str, Any], state: AgentState, turn_index: int
) -> AgentState:
if not cb._lifecycle_hooks:
return state
return await cb.fire_on_turn_end(GraphLifecycleContext(config=config), state, turn_index)

# ------------------------------------------------------------------ #
# Observability for events ToolNode doesn't already publish.
# ------------------------------------------------------------------ #
Expand Down
5 changes: 5 additions & 0 deletions agentflow/core/realtime/providers/gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ async def send_text(self, text: str) -> None:
session = self._require_session()
await session.send_realtime_input(text=text)

async def send_image(self, data: bytes, mime_type: str = "image/jpeg") -> None:
session = self._require_session()
_, types = self._genai()
await session.send_realtime_input(media=types.Blob(data=data, mime_type=mime_type))

async def send_activity_start(self) -> None:
session = self._require_session()
_, types = self._genai()
Expand Down
15 changes: 13 additions & 2 deletions agentflow/core/realtime/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,23 @@

logger = logging.getLogger(__name__)

LiveInputKind = Literal["audio", "text", "activity_start", "activity_end", "close"]
LiveInputKind = Literal["audio", "text", "image", "activity_start", "activity_end", "close"]


@dataclass(slots=True)
class LiveInput:
"""A single upstream transport frame. Construct via ``LiveInputQueue.send_*``.

``kind`` discriminates the frame; only the fields relevant to that kind are set
(``data``/``sample_rate`` for audio, ``text`` for text, neither for control frames).
(``data``/``sample_rate`` for audio, ``data``/``mime_type`` for image, ``text`` for
text, none for control frames).
"""

kind: LiveInputKind
data: bytes | None = None
text: str | None = None
sample_rate: int = INPUT_SAMPLE_RATE
mime_type: str | None = None


class LiveInputQueue:
Expand Down Expand Up @@ -71,6 +73,15 @@ def send_audio(self, data: bytes, sample_rate: int = INPUT_SAMPLE_RATE) -> None:
def send_text(self, text: str) -> None:
self._put(LiveInput(kind="text", text=text))

def send_image(self, data: bytes, mime_type: str = "image/jpeg") -> None:
"""Send a single image frame (e.g. a JPEG camera frame) into the live session.

Gemini Live accepts still images and video as individual frames; send video as a
stream of frames (~1 fps is the model's effective ceiling). ``mime_type`` must be an
image type the provider supports (default ``image/jpeg``).
"""
self._put(LiveInput(kind="image", data=data, mime_type=mime_type))

def send_activity_start(self) -> None:
self._put(LiveInput(kind="activity_start"))

Expand Down
23 changes: 13 additions & 10 deletions agentflow/prebuilt/agent/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from agentflow.core.graph.tool_node import ToolNode
from agentflow.core.realtime.base import RealtimeClient, RealtimeConfig
from agentflow.core.realtime.live_agent import LiveAgent
from agentflow.core.skills.models import SkillConfig
from agentflow.core.state.agent_state import AgentState
from agentflow.core.state.base_context import BaseContextManager
from agentflow.runtime.publisher.base_publisher import BasePublisher
from agentflow.storage.checkpointer.base_checkpointer import BaseCheckpointer
from agentflow.storage.media.storage.base import BaseMediaStore
from agentflow.storage.store.base_store import BaseStore
from agentflow.storage.store.memory_config import MemoryConfig
from agentflow.utils.callbacks import CallbackManager
from agentflow.utils.constants import END
from agentflow.utils.id_generator import BaseIDGenerator, DefaultIDGenerator
Expand All @@ -43,8 +44,8 @@ def __init__( # noqa: PLR0913
tools: Iterable[Callable] | None = None,
client: Any = None,
pass_user_info_to_mcp: bool = False,
skills: Any | None = None,
memory: Any | None = None,
skills: SkillConfig | None = None,
memory: MemoryConfig | None = None,
realtime_client_factory: Callable[[], RealtimeClient] | None = None,
live_node_name: str = "LIVE",
**agent_kwargs: Any,
Expand Down Expand Up @@ -76,7 +77,10 @@ def __init__( # noqa: PLR0913

@staticmethod
def _build_tool_node(
*, tools: list[Callable], client: Any, pass_user_info_to_mcp: bool
*,
tools: list[Callable],
client: Any,
pass_user_info_to_mcp: bool,
) -> ToolNode | None:
if not tools and client is None:
return None
Expand All @@ -103,23 +107,22 @@ def compile(
self,
checkpointer: BaseCheckpointer[StateT] | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
callback_manager: CallbackManager | None = None,
media_store: BaseMediaStore | None = None,
shutdown_timeout: float = 30.0,
) -> CompiledGraph:
# No media_store: realtime media (images/video) is sent frame-by-frame straight to
# the live model via the input queue (see LiveInputQueue.send_image); it is never
# offloaded to or resolved from a media store, so the parameter would be dead here.
self._configure_graph()

if self._graph is None: # pragma: no cover - _configure_graph always assigns
raise RuntimeError("graph configuration failed")

return self._graph.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
callback_manager=callback_manager
if callback_manager is not None
else CallbackManager(),
media_store=media_store,
shutdown_timeout=shutdown_timeout,
)
Loading
Loading