diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 16f6e78a..361a7040 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -1,8 +1,13 @@ from __future__ import annotations -from typing import Any +import copy +import logging +from typing import Any, Sequence from databricks.sdk import WorkspaceClient +from databricks_ai_bridge.tool_repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT + +logger = logging.getLogger(__name__) try: from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebaseClient, LakebasePool @@ -16,6 +21,109 @@ _checkpoint_imports_available = False +try: + from langchain_core.messages import AIMessage, ToolMessage + + _message_imports_available = True +except ImportError: + AIMessage = object # type: ignore + ToolMessage = object # type: ignore + _message_imports_available = False + + +def _build_tool_resume_repair(messages: Sequence[Any]) -> list[Any]: + """Build synthetic ``ToolMessage`` responses for orphan tool calls. + + Internal helper used by ``_repair_loaded_checkpoint_tuple``. When a + LangGraph run is killed mid-tool, the checkpointer preserves the + trailing ``AIMessage.tool_calls`` but the paired ``ToolMessage``s + never land. Replaying that state to the LLM fails because the API + (Anthropic in particular) requires every ``tool_use`` to be + immediately followed by a matching ``tool_result``. + + Walks the trailing assistant turn (the last contiguous block of + ``AIMessage`` / ``ToolMessage``) and returns a synthetic + ``ToolMessage`` for each ``tool_call`` id that lacks a matching + ``ToolMessage.tool_call_id``. The caller appends these to the + ``messages`` channel before the next model call. 
+ """ + if not _message_imports_available or not messages: + return [] + + # Trailing assistant turn: walk backwards until we hit a non-assistant/ + # non-tool message. That block is the "pending" turn whose tool_use ↔ + # tool_result pairing we need to enforce. + trailing_start = len(messages) + for i in range(len(messages) - 1, -1, -1): + if isinstance(messages[i], (AIMessage, ToolMessage)): + trailing_start = i + else: + break + + tool_call_ids: list[str] = [] + answered: set[str] = set() + for msg in messages[trailing_start:]: + if isinstance(msg, AIMessage): + for tc in getattr(msg, "tool_calls", None) or []: + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + if tc_id and tc_id not in tool_call_ids: + tool_call_ids.append(tc_id) + elif isinstance(msg, ToolMessage): + tcid = getattr(msg, "tool_call_id", None) + if tcid: + answered.add(tcid) + + orphans = [tc_id for tc_id in tool_call_ids if tc_id not in answered] + return [ + ToolMessage(tool_call_id=tc_id, content=DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT) + for tc_id in orphans + ] + + +def _repair_loaded_checkpoint_tuple(tup: Any) -> Any: + """Return a copy of ``tup`` with orphan tool_calls in its ``messages`` + channel closed by synthetic ``ToolMessage`` s. + + Called on every ``(a)get_tuple`` to make the served checkpoint + protocol-valid (every ``tool_use`` paired with a ``tool_result``) + transparently. A kill between the ``model`` and ``tools`` nodes leaves + the trailing ``AIMessage.tool_calls`` unpaired; on the NEXT turn that + state would otherwise leak into the LLM and be rejected by the + provider's pairing check. + + Idempotent — ``_build_tool_resume_repair`` is a no-op when state is + already clean. Cheap — the walk is O(trailing-turn). + + Side effect: the synthetic ``ToolMessage`` s added here become part of + the state LangGraph writes on the NEXT node boundary, so the repair + self-heals the DB row over time rather than re-computing on every read. 
+ """ + if tup is None or not _message_imports_available: + return tup + + checkpoint = getattr(tup, "checkpoint", None) + if not isinstance(checkpoint, dict): + return tup + channel_values = checkpoint.get("channel_values") + if not isinstance(channel_values, dict): + return tup + messages = channel_values.get("messages") + if not isinstance(messages, list) or not messages: + return tup + + repair = _build_tool_resume_repair(messages) + if not repair: + return tup + + logger.info( + "[durable] checkpoint read-time repair: injected %d synthetic ToolMessage(s)", + len(repair), + ) + new_checkpoint = copy.copy(checkpoint) + new_checkpoint["channel_values"] = dict(channel_values) + new_checkpoint["channel_values"]["messages"] = list(messages) + list(repair) + return tup._replace(checkpoint=new_checkpoint) + class CheckpointSaver(PostgresSaver): """ @@ -68,6 +176,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._lakebase.close() return False + def get_tuple(self, config): + """Return the checkpoint tuple, with trailing orphan tool_calls paired.""" + return _repair_loaded_checkpoint_tuple(super().get_tuple(config)) + class AsyncCheckpointSaver(AsyncPostgresSaver): """ @@ -122,3 +234,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): """Exit async context manager and close the connection pool.""" await self._lakebase.close() return False + + async def aget_tuple(self, config): + """Return the checkpoint tuple, with trailing orphan tool_calls paired.""" + return _repair_loaded_checkpoint_tuple(await super().aget_tuple(config)) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index 75f8ce9d..2b6f871b 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -446,3 +446,82 @@ async def test_async_checkpoint_saver_branch_resource_path(monkeypatch): assert "host=auto-db-host" in 
test_pool.conninfo assert saver._lakebase._is_autoscaling is True + + +class TestReadTimeCheckpointRepair: + """Read-time repair: aget_tuple / get_tuple returns a state where every + trailing ``AIMessage.tool_calls`` is paired with a ``ToolMessage``. Keeps + user-space free of middleware when the app is built on our savers.""" + + def _make_tuple(self, messages): + from collections import namedtuple + + FakeTuple = namedtuple( + "CheckpointTuple", + ["config", "checkpoint", "metadata", "parent_config", "pending_writes"], + ) + return FakeTuple( + config={}, + checkpoint={ + "v": 1, + "id": "ckpt", + "channel_values": {"messages": list(messages)}, + }, + metadata={}, + parent_config=None, + pending_writes=None, + ) + + def test_repairs_trailing_orphan_tool_call(self): + from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + tup = self._make_tuple( + [ + HumanMessage("hi"), + AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]), + ] + ) + repaired = _repair_loaded_checkpoint_tuple(tup) + msgs = repaired.checkpoint["channel_values"]["messages"] + assert len(msgs) == 3 + assert isinstance(msgs[-1], ToolMessage) + assert msgs[-1].tool_call_id == "c1" + + def test_noop_when_state_is_clean(self): + from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + tup = self._make_tuple( + [ + HumanMessage("hi"), + AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]), + ToolMessage(tool_call_id="c1", content="ok"), + AIMessage(content="done"), + ] + ) + repaired = _repair_loaded_checkpoint_tuple(tup) + # No repair added → tuple unchanged. 
+ assert repaired is tup + + def test_none_tuple_passes_through(self): + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + assert _repair_loaded_checkpoint_tuple(None) is None + + def test_does_not_mutate_original_messages_list(self): + from langchain_core.messages import AIMessage, HumanMessage + + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + original_messages = [ + HumanMessage("hi"), + AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]), + ] + tup = self._make_tuple(original_messages) + original_len = len(original_messages) + _repair_loaded_checkpoint_tuple(tup) + # Calling repair must NOT mutate the caller's original list. + assert len(original_messages) == original_len diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 94e4fcd8..5850892d 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -43,17 +43,27 @@ async def main(): DEFAULT_TOKEN_CACHE_DURATION_SECONDS, AsyncLakebaseSQLAlchemy, ) + from databricks_ai_bridge.tool_repair import sanitize_tool_items _session_imports_available = True except ImportError: SQLAlchemySession = object # type: ignore DEFAULT_TOKEN_CACHE_DURATION_SECONDS = None # type: ignore DEFAULT_POOL_RECYCLE_SECONDS = None # type: ignore + sanitize_tool_items = None # type: ignore _session_imports_available = False logger = logging.getLogger(__name__) +def _sanitize_items(items: list[Any]) -> list[Any]: + """Session-scoped wrapper around :func:`sanitize_tool_items` that only + sets the log prefix. Kept as a one-liner so existing + ``self._sanitize_items`` call sites stay stable. 
+ """ + return sanitize_tool_items(items, log_prefix="[durable] session items sanitized") + + class AsyncDatabricksSession(SQLAlchemySession): """ Async OpenAI Agents SDK Session implementation for Databricks Lakebase. @@ -179,6 +189,18 @@ async def _ensure_tables(self) -> None: await self._lakebase.create_schema() await super()._ensure_tables() + async def get_items(self, limit: Optional[int] = None) -> list[Any]: + """Return session items, always repaired for protocol validity. + + The returned list has every ``function_call`` paired with a + ``function_call_output`` — orphans from a durable-resume crash get + a synthetic output appended, and duplicates get deduped. The + underlying DB rows are not modified; this is a pure in-memory + filter, cheap to re-run on every call. + """ + items = await super().get_items(limit=limit) + return _sanitize_items(items) + @classmethod def _build_cache_key( cls, diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 335a915b..4b2246eb 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1289,6 +1289,116 @@ def test_init_branch_resource_path_resolves_host( ) +class TestSanitizeItems: + """Pure walker that reconciles orphan function_call / function_call_output + items. Shared by both the destructive ``repair()`` path and the read-time + ``get_items()`` filter.""" + + def _items_for(self, *types_and_ids): + # Helper: build items from (type, call_id) tuples. 
+ items = [] + for spec in types_and_ids: + if isinstance(spec, str): + items.append({"role": "user", "content": spec}) + else: + t, cid = spec + items.append( + {"type": t, "call_id": cid, "name": "f", "arguments": "{}"} + if t == "function_call" + else {"type": t, "call_id": cid, "output": "ok"} + ) + return items + + def test_noop_when_clean_returns_same_list(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for( + "hi", + ("function_call", "c1"), + ("function_call_output", "c1"), + "done", + ) + out = _sanitize_items(items) + assert out is items # caller can skip re-persistence + + def test_injects_synthetic_output_for_orphan_call(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for("hi", ("function_call", "c1")) + out = _sanitize_items(items) + assert len(out) == 3 + assert out[-1]["type"] == "function_call_output" + assert out[-1]["call_id"] == "c1" + + def test_injects_for_multiple_orphan_calls(self): + # Scenario the user hit: multiple parallel tool_calls, all orphaned. 
+ from databricks_openai.agents.session import _sanitize_items + + items = self._items_for( + "hi", + ("function_call", "c1"), + ("function_call", "c2"), + ("function_call", "c3"), + ) + out = _sanitize_items(items) + calls = [i for i in out if i.get("type") == "function_call"] + outputs = [i for i in out if i.get("type") == "function_call_output"] + assert len(calls) == 3 + assert len(outputs) == 3 + assert {o["call_id"] for o in outputs} == {"c1", "c2", "c3"} + + def test_drops_orphan_output_with_no_matching_call(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for("hi", ("function_call_output", "ghost")) + out = _sanitize_items(items) + assert all(i.get("type") != "function_call_output" for i in out) + + def test_dedupes_duplicate_calls_and_outputs(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for( + ("function_call", "c1"), + ("function_call", "c1"), + ("function_call_output", "c1"), + ("function_call_output", "c1"), + ) + out = _sanitize_items(items) + assert len(out) == 2 + + +class TestAsyncGetItemsAutoRepair: + """get_items() always applies read-time repair. Uses a minimal subclass + that bypasses parent SQLAlchemySession init so we can exercise the + override without a DB.""" + + def _fake_session(self, items): + from databricks_openai.agents.session import AsyncDatabricksSession, _sanitize_items + + class _FakeSession(AsyncDatabricksSession): + def __init__(self, stored): + # Bypass parent init — only need the stored items. 
+ self._stored = stored + + async def get_items(self, limit=None): + return _sanitize_items(list(self._stored)) + + return _FakeSession(items) + + @pytest.mark.asyncio + async def test_auto_repair_injects_synthetic_outputs(self): + sess = self._fake_session( + [ + {"role": "user", "content": "hi"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"type": "function_call", "call_id": "c2", "name": "f", "arguments": "{}"}, + ] + ) + items = await sess.get_items() + synth = [i for i in items if i.get("type") == "function_call_output"] + assert len(synth) == 2 + + # ============================================================================= # Schema Tests # ============================================================================= diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 903d466f..aed2c903 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -79,7 +79,51 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}")) await conn.run_sync(Base.metadata.create_all) + # Idempotent migration for tables created by earlier versions: add any + # columns introduced for durable-resume support. Each statement runs in + # its own transaction so an InsufficientPrivilege on one ALTER (another + # pod's SP owns the table but the schema is already migrated) doesn't + # poison the rest. A single mega-transaction would abort entirely on the + # first owner-check failure even with IF NOT EXISTS. 
+ migration_stmts = ( + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS original_request TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"CREATE INDEX IF NOT EXISTS idx_responses_stale " + f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " + "WHERE status = 'in_progress'", + ) + skipped_migrations: list[str] = [] + for stmt in migration_stmts: + try: + async with _engine.begin() as conn: + await conn.execute(text(stmt)) + except Exception as exc: + msg = str(exc).lower() + if "insufficientprivilege" in msg or "must be owner" in msg: + skipped_migrations.append(stmt.split("\n")[0]) + continue + raise + _initialized = True + if skipped_migrations: + # WARN-level summary: if the DB was previously migrated by another SP + # this is fine, but if it's genuinely a new table and our SP lacks + # ALTER, claim/heartbeat queries will fail later with a confusing + # "column does not exist" — surface it clearly at startup. + logger.warning( + "[DB] Skipped %d durability migration(s) due to insufficient " + "privilege — assuming table was already migrated by another " + "service principal. Crash-resume will fail with 'column does " + "not exist' if this assumption is wrong. 
Skipped: %s", + len(skipped_migrations), + ", ".join(skipped_migrations), + ) logger.info("[DB] Engine and schema ready") diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py index 1d876dc7..7014a7db 100644 --- a/src/databricks_ai_bridge/long_running/models.py +++ b/src/databricks_ai_bridge/long_running/models.py @@ -14,7 +14,12 @@ class Base(DeclarativeBase): class Response(Base): - """Response status tracking for background agent tasks.""" + """Response status tracking for background agent tasks. + + Durability columns (``owner_pod_id``, ``heartbeat_at``, ``attempt_number``, + ``original_request``) support crash-resume: another pod can atomically + claim a stale in-progress row and replay the agent loop. + """ __tablename__ = "responses" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -25,12 +30,23 @@ class Response(Base): DateTime(timezone=True), nullable=False, server_default=func.now() ) trace_id: Mapped[str | None] = mapped_column(Text, nullable=True) + owner_pod_id: Mapped[str | None] = mapped_column(Text, nullable=True) + heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) + original_request: Mapped[str | None] = mapped_column(Text, nullable=True) messages = relationship("Message", back_populates="response", cascade="all, delete-orphan") class Message(Base): - """Stream events and output items for a response.""" + """Stream events and output items for a response. + + ``attempt_number`` tags events by which run attempt emitted them so that + resumed runs append to the same event log without overwriting earlier + (abandoned) attempts, and retrieve can filter to the latest attempt only. 
+ """ __tablename__ = "messages" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -44,6 +60,9 @@ class Message(Base): sequence_number: Mapped[int] = mapped_column( Integer, primary_key=True, nullable=False, default=0 ) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) item: Mapped[str | None] = mapped_column(Text, nullable=True) stream_event: Mapped[str | None] = mapped_column(Text, nullable=True) diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py index fcd86b29..06d30edb 100644 --- a/src/databricks_ai_bridge/long_running/repository.py +++ b/src/databricks_ai_bridge/long_running/repository.py @@ -5,15 +5,37 @@ from typing import Any, NamedTuple from sqlalchemy import select, update +from sqlalchemy.sql import bindparam, text from databricks_ai_bridge.long_running.db import session_scope -from databricks_ai_bridge.long_running.models import Message, Response +from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response -async def create_response(response_id: str, status: str) -> None: - """Insert a new response.""" +async def create_response( + response_id: str, + status: str, + *, + owner_pod_id: str | None = None, + original_request: dict[str, Any] | None = None, +) -> None: + """Insert a new response row. + + ``owner_pod_id`` and ``original_request`` are optional so that non-durable + callers (tests, legacy flows) can still create rows without durability + metadata. When present, they enable heartbeat + crash-resume semantics. 
+ """ async with session_scope() as session: - session.add(Response(response_id=response_id, status=status)) + session.add( + Response( + response_id=response_id, + status=status, + owner_pod_id=owner_pod_id, + heartbeat_at=datetime.now().astimezone() if owner_pod_id else None, + original_request=( + json.dumps(original_request) if original_request is not None else None + ), + ) + ) await session.commit() @@ -43,18 +65,84 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None: await session.commit() +async def heartbeat_response(response_id: str, pod_id: str) -> bool: + """Update heartbeat_at for a response IFF this pod owns it. + + Returns True on success. A False result means the claim has been lost + (another pod took over, or the run finished and heartbeat should stop). + """ + async with session_scope() as session: + stmt = ( + update(Response) + .where(Response.response_id == response_id, Response.owner_pod_id == pod_id) + .values(heartbeat_at=datetime.now().astimezone()) + ) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 + + +async def claim_stale_response( + response_id: str, + new_owner_pod_id: str, + stale_threshold_seconds: float, +) -> int | None: + """Atomically claim an in-progress response whose heartbeat has gone stale. + + Uses a single conditional UPDATE so exactly one caller wins on contention: + claim only succeeds if status is ``in_progress`` AND + (``owner_pod_id IS NULL`` OR ``heartbeat_at`` is older than the threshold). + + Returns the new ``attempt_number`` on success, or ``None`` if the row did + not satisfy the claim conditions (already completed, already claimed by a + live pod, or nonexistent). + """ + # Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for + # the incremented column as ergonomically. Using a single statement keeps the + # claim atomic without an explicit transaction-level lock. 
+ stmt = text( + f""" + UPDATE {AGENT_DB_SCHEMA}.responses + SET owner_pod_id = :pod, + heartbeat_at = now(), + attempt_number = attempt_number + 1 + WHERE response_id = :rid + AND status = 'in_progress' + AND (owner_pod_id IS NULL + OR heartbeat_at IS NULL + OR heartbeat_at < now() - make_interval(secs => :threshold)) + RETURNING attempt_number + """ + ).bindparams( + bindparam("pod", type_=None), + bindparam("rid", type_=None), + bindparam("threshold", type_=None), + ) + async with session_scope() as session: + result = await session.execute( + stmt, + {"pod": new_owner_pod_id, "rid": response_id, "threshold": stale_threshold_seconds}, + ) + row = result.first() + await session.commit() + return int(row[0]) if row else None + + async def append_message( response_id: str, sequence_number: int, item: str | None = None, stream_event: dict[str, Any] | None = None, + *, + attempt_number: int = 1, ) -> None: - """Append a message (stream event) for a response.""" + """Append a message (stream event) for a response, tagged with attempt_number.""" async with session_scope() as session: session.add( Message( response_id=response_id, sequence_number=sequence_number, + attempt_number=attempt_number, item=item, stream_event=json.dumps(stream_event) if stream_event is not None else None, ) @@ -65,22 +153,26 @@ async def append_message( async def get_messages( response_id: str, after_sequence: int | None = None, -) -> list[tuple[int, str | None, dict[str, Any] | None]]: - """Fetch messages for a response, optionally after a sequence number. + *, + attempt_number: int | None = None, +) -> list[tuple[int, str | None, dict[str, Any] | None, int]]: + """Fetch messages for a response, optionally filtering by sequence / attempt. - Returns list of (sequence_number, item, stream_event_dict). + Returns list of ``(sequence_number, item, stream_event_dict, attempt_number)``. 
""" async with session_scope() as session: stmt = select(Message).where(Message.response_id == response_id) if after_sequence is not None: stmt = stmt.where(Message.sequence_number > after_sequence) + if attempt_number is not None: + stmt = stmt.where(Message.attempt_number == attempt_number) stmt = stmt.order_by(Message.sequence_number) result = await session.execute(stmt) rows = result.scalars().all() out = [] for r in rows: evt = json.loads(r.stream_event) if r.stream_event else None - out.append((r.sequence_number, r.item, evt)) + out.append((r.sequence_number, r.item, evt, r.attempt_number)) return out @@ -89,6 +181,10 @@ class ResponseInfo(NamedTuple): status: str created_at: datetime trace_id: str | None + owner_pod_id: str | None + heartbeat_at: datetime | None + attempt_number: int + original_request: dict[str, Any] | None async def get_response(response_id: str) -> ResponseInfo | None: @@ -97,5 +193,14 @@ async def get_response(response_id: str) -> ResponseInfo | None: result = await session.execute(select(Response).where(Response.response_id == response_id)) row = result.scalar_one_or_none() if row: - return ResponseInfo(row.response_id, row.status, row.created_at, row.trace_id) + return ResponseInfo( + row.response_id, + row.status, + row.created_at, + row.trace_id, + row.owner_pod_id, + row.heartbeat_at, + row.attempt_number, + json.loads(row.original_request) if row.original_request else None, + ) return None diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index b3374d67..5d10dfdd 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -6,9 +6,12 @@ raise RuntimeError("The long_running module requires Python 3.11 or later.") import asyncio +import copy import inspect import json import logging +import os +import socket import time import uuid from collections.abc import AsyncGenerator @@ -34,19 +37,25 @@ from 
databricks_ai_bridge.long_running.db import dispose_db, init_db, is_db_configured from databricks_ai_bridge.long_running.repository import ( append_message, + claim_stale_response, create_response, get_messages, get_response, + heartbeat_response, update_response_status, update_response_trace_id, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings +from databricks_ai_bridge.tool_repair import sanitize_tool_items from databricks_ai_bridge.utils.annotations import experimental logger = logging.getLogger(__name__) BACKGROUND_KEY = "background" +# One ID per process so heartbeats + claims have a stable owner identity. +_POD_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" + async def _deferred_mark_failed( response_id: str, delay: float = 2.0, reason: str = "Task timed out" @@ -65,7 +74,8 @@ async def _deferred_mark_failed( # or SELECT FOR UPDATE on the response row to serialise writers. async with asyncio.timeout(delay): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) error_event = { "type": "error", @@ -75,7 +85,13 @@ async def _deferred_mark_failed( "code": "task_timeout", }, } - await append_message(response_id, next_seq, item=None, stream_event=error_event) + await append_message( + response_id, + next_seq, + item=None, + stream_event=error_event, + attempt_number=attempt, + ) await update_response_status(response_id, "failed") logger.info("Marked %s as failed (reason: %s)", response_id, reason) @@ -91,10 +107,21 @@ async def _deferred_mark_failed( ) +async def _current_attempt(response_id: str) -> int: + """Fetch the current attempt_number for a response, defaulting to 1.""" + resp = await get_response(response_id) + return resp.attempt_number if resp else 1 + + def _sse_event(event_type: str, data: dict[str, Any] | str) -> 
str: - """Format an SSE event per Open Responses spec.""" + """Emit ``data:``-only SSE frames. Match the non-durable stream format + so downstream SSE parsers dispatch on the payload's ``type`` field + rather than a leading ``event:`` name line. Claude's multi-response + stream (one response.created/completed pair per tool iteration) plus + the event-name prefix confuses the AI SDK's Databricks provider into + a retry loop.""" payload = data if isinstance(data, str) else json.dumps(data) - return f"event: {event_type}\ndata: {payload}\n\n" + return f"data: {payload}\n\n" def _age_seconds(created_at: datetime) -> float: @@ -105,9 +132,224 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() +# Tool-pair items are collected into the main inheritance bucket and +# preserved in event-log order. Narrative ``message`` items are routed +# separately in the collector (see ``_collect_prior_attempt_tool_events``) +# so they can be hoisted past the tool pairs for Anthropic adapter +# compatibility — adding ``"message"`` here would break that hoist. +_TOOL_PAIR_TYPES = ("function_call", "function_call_output") + + +def _iter_attempt_events(messages: list[tuple], attempt: int): + """Yield ``(event_type, event_dict)`` pairs for the given attempt. + + Skips rows from other attempts and non-dict event payloads so callers + can write single-concern walkers without repeating the same filter. 
+ """ + for _seq, _item_json, evt, attempt_tag in messages: + if attempt_tag != attempt: + continue + if not isinstance(evt, dict): + continue + yield evt.get("type"), evt + + +def _extract_completed_items(messages: list[tuple], attempt: int) -> tuple[list[dict], list[dict]]: + """Scan ``.done`` events and partition into (tool pairs, narrative).""" + tool_items: list[dict] = [] + narrative_items: list[dict] = [] + for t, evt in _iter_attempt_events(messages, attempt): + if t != "response.output_item.done": + continue + item = evt.get("item") + if not isinstance(item, dict): + continue + itype = item.get("type") + if itype == "message": + narrative_items.append(item) + elif itype in _TOOL_PAIR_TYPES: + tool_items.append(item) + return tool_items, narrative_items + + +def _reassemble_partial_message(messages: list[tuple], attempt: int) -> dict | None: + """Return a synthetic assistant message if the attempt ended with a + never-completed in-flight text item, else ``None``. + + Tracks ``output_item.added`` for message items, accumulates their + ``output_text.delta`` frames, and drops the tracker when a matching + ``.done`` arrives (that item is authoritative). Anything left at the + end is an unfinished message whose deltas we stitch into a synthetic + item so the next attempt's LLM can continue the prior narration. 
+ """ + in_progress_text: dict[str, list[str]] = {} + in_progress_order: list[str] = [] + for t, evt in _iter_attempt_events(messages, attempt): + if t == "response.output_item.added": + item = evt.get("item") + if isinstance(item, dict) and item.get("type") == "message": + iid = item.get("id") + if iid: + in_progress_text.setdefault(iid, []) + if iid not in in_progress_order: + in_progress_order.append(iid) + elif t == "response.output_item.done": + item = evt.get("item") + if isinstance(item, dict): + iid = item.get("id") + if iid in in_progress_text: + del in_progress_text[iid] + in_progress_order.remove(iid) + elif t == "response.output_text.delta": + iid = evt.get("item_id") + delta = evt.get("delta") + if iid and isinstance(delta, str) and iid in in_progress_text: + in_progress_text[iid].append(delta) + + for iid in in_progress_order: + chunks = in_progress_text.get(iid) or [] + if not chunks: + continue + return { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "".join(chunks), "annotations": []}, + ], + } + return None + + +def _collect_prior_attempt_tool_events( + messages: list[tuple], prior_attempt_number: int +) -> list[dict]: + """Return items the given prior attempt emitted, reordered to be a + valid provider message sequence on replay. + + Composition: (completed tool pairs in event order) → (completed + narrative messages in event order) → (partial reassembled message, if + any). Claude's raw stream interleaves narrative ``message`` items + between each ``function_call`` and its ``function_call_output``, which + in Anthropic format would look like ``assistant(tool_use)`` → + ``assistant(text)`` → ``user(tool_result)`` and trip the provider's + "tool_use must be immediately followed by tool_result" rule (HTTP + 400). Hoisting narrative past the tool pairs keeps each function_call + adjacent to its output and lets the narrative flow as a trailing + assistant block. 
+ + ``messages`` is the repository's tuples ``(seq, item_json, + stream_event, attempt_number)``. + """ + tool_items, narrative_items = _extract_completed_items(messages, prior_attempt_number) + partial = _reassemble_partial_message(messages, prior_attempt_number) + if partial is not None: + narrative_items.append(partial) + return tool_items + narrative_items + + +def _sanitize_request_input(request_dict: dict[str, Any]) -> dict[str, Any]: + """Reconcile orphaned function_call / function_call_output items in + ``request['input']`` via the shared :func:`sanitize_tool_items` walker. + + Walks the whole history (not just the trailing turn) because UI-echoed + history can carry orphans from prior crashed turns mid-list. + Mutates ``request_dict['input']`` in place and returns the dict for + caller convenience. + """ + items = request_dict.get("input") + if not isinstance(items, list) or not items: + return request_dict + request_dict["input"] = sanitize_tool_items(items, log_prefix="[durable] input sanitized") + return request_dict + + +def _rotate_conversation_id( + request_dict: dict[str, Any], + new_attempt_number: int, + response_id: str, +) -> dict[str, Any]: + """Rotate the conversation anchor to a per-attempt value. + + After a crash, attempt N+1 should see a FRESH checkpointer / session so it + doesn't inherit mid-turn state that the SDK can't repair cleanly (most + notably the LangGraph stream-event attempt-boundary orphan artifact). + The handler's priority chain is: + + 1. custom_inputs.thread_id / session_id (explicit, wins) + 2. context.conversation_id (fallback) + 3. auto-generated (last resort) + + We drop (1), pick the current base anchor, and write ``{base}::attempt-N`` + into (2). The handler then resolves to a fresh key for this attempt while + still being deterministic across retries of the same attempt. 
+ + The LLM sees full turn history via ``original_request.input``, which was + captured at the initial POST — before any attempt ran, so it's clean by + construction. + """ + custom_inputs = request_dict.get("custom_inputs") + if not isinstance(custom_inputs, dict): + custom_inputs = {} + + base_anchor = ( + custom_inputs.get("thread_id") + or custom_inputs.get("session_id") + or (request_dict.get("context") or {}).get("conversation_id") + or response_id + ) + + custom_inputs.pop("thread_id", None) + custom_inputs.pop("session_id", None) + request_dict["custom_inputs"] = custom_inputs + + ctx = request_dict.get("context") or {} + ctx = dict(ctx) + rotated = f"{base_anchor}::attempt-{new_attempt_number}" + ctx["conversation_id"] = rotated + request_dict["context"] = ctx + logger.info( + "[durable] rotated conversation_id for resume response_id=%s attempt=%d base=%s rotated=%s", + response_id, + new_attempt_number, + base_anchor, + rotated, + ) + return request_dict + + +def _inject_conversation_id(request_dict: dict[str, Any], response_id: str) -> dict[str, Any]: + """Anchor the request to ``response_id`` as its conversation. + + Operates on a plain dict — the caller is responsible for converting to/from + pydantic via ``model_dump()`` and the server's validator. + + Templates that back this server use ``context.conversation_id`` (and + ``custom_inputs.thread_id`` / ``custom_inputs.session_id``) as priority-2 + fallbacks to derive their stateful thread/session key. If neither is + provided by the client, a resumed invocation from another pod would + generate a *fresh* ID and miss the checkpoint entirely — so we stamp the + conversation_id here before persisting the request, guaranteeing that + every replay hits the same memory store. + + Client-supplied values take precedence and are left untouched. 
+ """ + out = copy.deepcopy(request_dict) if request_dict else {} + custom_inputs = out.get("custom_inputs") or {} + if custom_inputs.get("thread_id") or custom_inputs.get("session_id"): + return out + ctx = out.get("context") or {} + if ctx.get("conversation_id"): + return out + ctx = dict(ctx) + ctx["conversation_id"] = response_id + out["context"] = ctx + return out + + @experimental class LongRunningAgentServer(AgentServer): - """AgentServer subclass adding background mode and retrieve endpoints. + """AgentServer subclass adding background mode, retrieve endpoints, and + durable resume. Only compatible with ``ResponsesAgent`` mode. @@ -125,6 +367,16 @@ class LongRunningAgentServer(AgentServer): ``LAKEBASE_INSTANCE_NAME``, ``LAKEBASE_AUTOSCALING_ENDPOINT``, or both ``LAKEBASE_AUTOSCALING_PROJECT`` and ``LAKEBASE_AUTOSCALING_BRANCH``. + Durable resume: when ``GET /responses/{id}`` sees an ``in_progress`` run + whose owning pod has stopped heartbeating for more than + ``heartbeat_stale_threshold_seconds``, the retrieving pod atomically claims + the run and re-invokes the registered handler with a rotated + ``conversation_id`` (so the agent SDK resolves to a fresh thread/session), + the original request's ``input`` enriched with the prior attempt's already + emitted tool calls / outputs / narrative, and an ``[INTERRUPTED]`` synthetic + output paired with any tool call that didn't finish. Completed work is + preserved; only the interrupted step re-runs. + Args: enable_chat_proxy: Whether to enable the chat proxy endpoint. db_instance_name: Lakebase provisioned instance name. Overrides @@ -143,6 +395,12 @@ class LongRunningAgentServer(AgentServer): Defaults to 5000 (5 seconds). cleanup_timeout_seconds: Timeout for DB cleanup after task failure. Defaults to 7.0. + heartbeat_interval_seconds: How often the owning pod writes + ``heartbeat_at`` while a run is in flight. Defaults to 3.0. 
+ heartbeat_stale_threshold_seconds: Age at which a heartbeat is + considered stale and another pod may claim the run. Also used + as the grace window for a freshly-created run that hasn't + written its first heartbeat yet. Defaults to 10.0. """ _SUPPORTED_AGENT_TYPE = "ResponsesAgent" @@ -162,6 +420,8 @@ def __init__( poll_interval_seconds: float = 1.0, db_statement_timeout_ms: int = 5000, cleanup_timeout_seconds: float = 7.0, + heartbeat_interval_seconds: float = 3.0, + heartbeat_stale_threshold_seconds: float = 10.0, ): if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( @@ -173,11 +433,18 @@ def __init__( poll_interval_seconds=poll_interval_seconds, db_statement_timeout_ms=db_statement_timeout_ms, cleanup_timeout_seconds=cleanup_timeout_seconds, + heartbeat_interval_seconds=heartbeat_interval_seconds, + heartbeat_stale_threshold_seconds=heartbeat_stale_threshold_seconds, ) self._db_instance_name = db_instance_name self._db_autoscaling_endpoint = db_autoscaling_endpoint self._db_project = db_project self._db_branch = db_branch + # Track in-flight background tasks per response_id so the debug-kill + # endpoint can simulate a pod crash without tearing the whole pod + # down. Not load-bearing for correctness — durability still relies on + # DB state, this is just a test affordance. + self._running_tasks: dict[str, asyncio.Task] = {} super().__init__(agent_type, enable_chat_proxy=enable_chat_proxy) def _setup_routes(self) -> None: @@ -195,6 +462,41 @@ async def cancel_endpoint(response_id: str): detail="Cancellation is not yet implemented.", ) + # Debug endpoint for testing durable resume: cancels the in-flight + # asyncio task that owns the given response_id WITHOUT running the + # _task_scope cleanup, so the DB row stays in_progress with a + # going-stale heartbeat — exactly the shape a real pod crash leaves. + # Opt-in via env var so it's never exposed in production. 
+ if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") == "1": + + @self.app.post("/_debug/kill_task/{response_id}") + async def _debug_kill_task(response_id: str): + task = self._running_tasks.get(response_id) + if task is None: + logger.info( + "[durable] kill endpoint: no task response_id=%s on pod=%s", + response_id, + _POD_ID, + ) + raise HTTPException( + status_code=404, + detail=( + "No in-flight task for that response_id on this pod " + "(may already have finished or be running on another pod)." + ), + ) + logger.info( + "[durable] kill endpoint: cancelling task response_id=%s pod=%s", + response_id, + _POD_ID, + ) + task.cancel() + return { + "response_id": response_id, + "pod_id": _POD_ID, + "status": "task_cancelled", + } + db_configured = is_db_configured() @self.app.get("/responses/{response_id}") @@ -265,6 +567,9 @@ async def _handle_invocations_request( data = {k: v for k, v in data.items() if k not in (BACKGROUND_KEY, MLFLOW_STREAM_KEY)} return_trace_id = (get_request_headers().get(RETURN_TRACE_HEADER) or "").lower() == "true" + if self._settings.auto_sanitize_input: + data = _sanitize_request_input(data) + try: request_data = self.validator.validate_and_convert_request(data) except ValueError as e: @@ -290,11 +595,29 @@ async def _handle_background_request( ) -> dict[str, Any] | StreamingResponse: """Start a new conversation and return response_id immediately.""" response_id = f"resp_{uuid.uuid4().hex[:24]}" - await create_response(response_id, "in_progress") + # Anchor the conversation to response_id so any future replay from a + # different pod resolves to the same agent-SDK thread/session. We + # round-trip through dict + validator so the handler still receives a + # pydantic ResponsesAgentRequest (its declared arg type). The + # declared param type is ``dict`` but the runtime object is a pydantic + # model from ``validate_and_convert_request``; fall back to ``dict()`` + # when tests pass a plain dict directly. 
+ dump = getattr(request_data, "model_dump", None) + request_dict = dump() if callable(dump) else dict(request_data) + durable_dict = _inject_conversation_id(request_dict, response_id) + durable_request = self.validator.validate_and_convert_request(durable_dict) + await create_response( + response_id, + "in_progress", + owner_pod_id=_POD_ID, + original_request=durable_dict, + ) - logger.debug( - "Background response created", - extra={"response_id": response_id, "stream": is_streaming}, + logger.info( + "Background response created response_id=%s stream=%s pod=%s", + response_id, + is_streaming, + _POD_ID, ) response_obj: dict[str, Any] = { @@ -309,21 +632,99 @@ async def _handle_background_request( } # Fire-and-forget is intentional — task status is persisted to the database. + # We still track the task handle so the debug-kill endpoint can simulate + # a crash (and so we know whether a claim target lives on this pod). if is_streaming: - asyncio.create_task( - self._run_background_stream(response_id, request_data, return_trace_id) + task = asyncio.create_task( + self._run_background_stream( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) + self._track_task(response_id, task) return await self._handle_retrieve_request( response_id, stream=True, starting_after=0, ) else: - asyncio.create_task( - self._run_background_invoke(response_id, request_data, return_trace_id) + task = asyncio.create_task( + self._run_background_invoke( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) + self._track_task(response_id, task) return response_obj + def _track_task(self, response_id: str, task: asyncio.Task) -> None: + """Record a background task so the debug-kill endpoint can find it.""" + self._running_tasks[response_id] = task + task.add_done_callback(lambda _t: self._running_tasks.pop(response_id, None)) + + @asynccontextmanager + async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: + """Keep the response 
row's heartbeat_at fresh while the body runs. + + A background task writes ``heartbeat_at = now()`` every + ``heartbeat_interval_seconds`` for the owning pod. It stops when the + body returns/raises. Heartbeat write failures are logged but do not + interrupt the agent run — the stale-run check will detect a dead pod. + """ + interval = self._settings.heartbeat_interval_seconds + stop = asyncio.Event() + + async def _beat(): + beats = 0 + logger.info( + "[durable] heartbeat start response_id=%s pod=%s interval=%.1fs", + response_id, + _POD_ID, + interval, + ) + try: + while not stop.is_set(): + try: + await heartbeat_response(response_id, _POD_ID) + beats += 1 + # Sampled heartbeat log so the lifecycle is visible + # without spamming every interval. Every 5th (~15s + # at 3s interval) is a good compromise. + if beats % 5 == 1: + logger.info( + "[durable] heartbeat beat#%d response_id=%s pod=%s", + beats, + response_id, + _POD_ID, + ) + except Exception: + logger.warning( + "[durable] heartbeat write failed response_id=%s; will retry", + response_id, + exc_info=True, + ) + try: + await asyncio.wait_for(stop.wait(), timeout=interval) + except TimeoutError: + pass + except asyncio.CancelledError: + pass + logger.info( + "[durable] heartbeat stop response_id=%s pod=%s total_beats=%d", + response_id, + _POD_ID, + beats, + ) + + hb_task = asyncio.create_task(_beat(), name=f"heartbeat-{response_id}") + try: + yield + finally: + stop.set() + hb_task.cancel() + try: + await hb_task + except (asyncio.CancelledError, Exception): + pass + @asynccontextmanager async def _task_scope( self, response_id: str, state: dict[str, Any] @@ -348,7 +749,8 @@ async def _task_scope( # TODO: sequence number computation is racy (see _deferred_mark_failed). 
async with asyncio.timeout(self._settings.cleanup_timeout_seconds): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -361,6 +763,7 @@ async def _task_scope( "code": "task_failed", }, }, + attempt_number=attempt, ) await update_response_status(response_id, "failed") except Exception: @@ -382,11 +785,19 @@ async def _run_background_stream( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the streaming agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_stream(response_id, request_data, return_trace_id, state) + async with self._task_scope(response_id, state), self._heartbeat(response_id): + await self._do_background_stream( + response_id, + request_data, + return_trace_id, + state, + attempt_number=attempt_number, + ) def transform_stream_event(self, event: dict, response_id: str) -> dict: """Override to transform events before persistence (e.g. 
replace placeholder IDs).""" @@ -398,6 +809,8 @@ async def _do_background_stream( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via stream_fn, persist each stream event as a message row.""" stream_fn = get_stream_function() @@ -406,8 +819,23 @@ async def _do_background_stream( raise RuntimeError("No stream function registered; cannot run background stream") func_name = stream_fn.__name__ + logger.info( + "[durable] background stream start response_id=%s attempt=%d pod=%s handler=%s", + response_id, + attempt_number, + _POD_ID, + func_name, + ) all_chunks: list[dict[str, Any]] = [] - seq = 0 + # Continue sequence numbering across attempts so the client's cursor + # never rewinds on resume. First attempt starts at 0 and skips the DB + # lookup — keeps the fast path identical to pre-resume behavior and + # avoids an extra query per background request. + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + seq = 0 with mlflow.start_span(name=func_name) as span: span.set_inputs(request_data) @@ -420,16 +848,27 @@ async def _do_background_stream( evt_type = evt.get("type", "message") logger.debug( "SSE event (background)", - extra={"response_id": response_id, "seq": seq, "type": evt_type}, + extra={ + "response_id": response_id, + "seq": seq, + "type": evt_type, + "attempt": attempt_number, + }, ) await append_message( response_id, seq, item=json.dumps(item) if item is not None else None, stream_event=evt, + attempt_number=attempt_number, ) seq += 1 state["seq"] = seq + # Explicit yield so task.cancel() propagates promptly on + # tight event streams. The OpenAI Agents Runner's + # stream_events() awaits a queue that empties fast enough + # that cancellation can sit for tens of seconds without this. 
+ await asyncio.sleep(0) span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "openai") span.set_outputs(ResponsesAgent.responses_agent_output_reducer(all_chunks)) @@ -439,12 +878,17 @@ async def _do_background_stream( response_id, seq, stream_event={"trace_id": span.trace_id}, + attempt_number=attempt_number, ) await update_response_status(response_id, "completed") - logger.debug( - "Background stream completed", - extra={"response_id": response_id, "total_events": seq}, + logger.info( + "[durable] background stream completed response_id=%s attempt=%d " + "total_events=%d pod=%s", + response_id, + attempt_number, + seq, + _POD_ID, ) async def _run_background_invoke( @@ -452,11 +896,19 @@ async def _run_background_invoke( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the invoke agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_invoke(response_id, request_data, return_trace_id, state) + async with self._task_scope(response_id, state), self._heartbeat(response_id): + await self._do_background_invoke( + response_id, + request_data, + return_trace_id, + state, + attempt_number=attempt_number, + ) async def _do_background_invoke( self, @@ -464,6 +916,8 @@ async def _do_background_invoke( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via invoke_fn, persist each output item as a message row.""" invoke_fn = get_invoke_function() @@ -485,19 +939,27 @@ async def _do_background_invoke( span.set_outputs(result) output = result.get("output", []) + # Continue sequence numbering across attempts (see _do_background_stream). 
+ if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + base_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + base_seq = 0 for i, item in enumerate(output): item_dict = ( item if isinstance(item, dict) else (item.model_dump() if hasattr(item, "model_dump") else {"content": str(item)}) ) + seq = base_seq + i await append_message( response_id, - i, + seq, item=json.dumps(item_dict), stream_event={"type": "response.output_item.done", "item": item_dict}, + attempt_number=attempt_number, ) - state["seq"] = i + 1 + state["seq"] = seq + 1 if return_trace_id: await update_response_trace_id(response_id, span.trace_id) await update_response_status(response_id, "completed") @@ -506,6 +968,157 @@ async def _do_background_invoke( extra={"response_id": response_id, "output_items": len(output)}, ) + async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: + """If ``resp`` is a stale in-progress run, attempt an atomic claim. + + On success, kick off a new background task that re-invokes the handler + on a rotated conversation anchor with the replayed input enriched by + the prior attempt's emitted items, and returns the new + ``attempt_number``. On failure (another pod won, or the run is no + longer stale), returns ``None``. + + This is the lazy resume path: triggered by a client retrieve. Pods + don't poll for stale work proactively in v1 — if no client ever calls + ``GET /responses/{id}``, the task_timeout sweep eventually marks it + failed. + """ + if resp.status != "in_progress": + return None + # The run may be freshly started but too young to have a heartbeat yet; + # respect the creation age as a grace period equal to the stale + # threshold. Otherwise a quick follow-up retrieve could hijack a + # running pod before it ever writes its first heartbeat. 
+ if resp.heartbeat_at is None: + age = _age_seconds(resp.created_at) + if age < self._settings.heartbeat_stale_threshold_seconds: + logger.debug( + "[durable] claim skipped response_id=%s reason=grace_period " + "age=%.1fs threshold=%.1fs", + response_id, + age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + else: + hb_age = _age_seconds(resp.heartbeat_at) + if hb_age < self._settings.heartbeat_stale_threshold_seconds: + # Heartbeat is fresh — owner is alive. Common case, keep + # quiet at debug so we don't spam every poll iteration. + logger.debug( + "[durable] claim skipped response_id=%s reason=heartbeat_fresh " + "age=%.1fs threshold=%.1fs", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + logger.info( + "[durable] stale heartbeat detected response_id=%s " + "heartbeat_age=%.1fs threshold=%.1fs current_owner=%s", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + resp.owner_pod_id, + ) + if resp.original_request is None: + # Nothing to replay from — the run predates durability metadata. + logger.warning( + "[durable] cannot resume response_id=%s reason=no_original_request", + response_id, + ) + return None + + logger.info( + "[durable] attempting claim response_id=%s current_attempt=%d new_owner=%s", + response_id, + resp.attempt_number, + _POD_ID, + ) + new_attempt = await claim_stale_response( + response_id, + new_owner_pod_id=_POD_ID, + stale_threshold_seconds=self._settings.heartbeat_stale_threshold_seconds, + ) + if new_attempt is None: + # Someone else owns it, or the row was updated between the read and + # the claim. Expected under contention. + logger.info( + "[durable] claim lost response_id=%s (another pod won or row changed)", + response_id, + ) + return None + + # Build a "resume" request by REPLAYING the original POST's input on a + # ROTATED conversation anchor, enriched with the prior attempt's + # already-emitted tool events: + # + # 1. 
Carry forward the prior attempt's function_call / function_call_output + # items so the LLM sees what's already been done. Without this, + # attempt N+1 re-plans from just the user's latest message and + # re-emits tool calls that previously completed (e.g. it re-runs + # get_time even though only deep_research was interrupted). The + # interrupted tool's orphan function_call gets a synthetic + # "interrupted" output via the sanitizer below. + # + # 2. Rotate conversation_id so the handler's SDK helpers resolve to a + # FRESH thread_id / session_id for this attempt. Without this, the + # handler would reload the crashed attempt's mid-turn checkpoint, + # which on LangGraph produces a stream-event orphan artifact at the + # attempt boundary (rotation-findings.md stress test). Attempt N+1 + # runs on a clean checkpointer; the prior-attempt tool events in + # input[] are the single source of truth for what already ran. + existing = await get_messages(response_id, after_sequence=None) + next_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + prior_tool_events = _collect_prior_attempt_tool_events( + existing, prior_attempt_number=new_attempt - 1 + ) + + resume_dict = copy.deepcopy(resp.original_request) + if prior_tool_events: + resume_input = list(resume_dict.get("input") or []) + resume_input.extend(prior_tool_events) + resume_dict["input"] = resume_input + logger.info( + "[durable] resume inherited %d tool-event item(s) from attempt %d response_id=%s", + len(prior_tool_events), + new_attempt - 1, + response_id, + ) + if self._settings.auto_sanitize_input: + resume_dict = _sanitize_request_input(resume_dict) + resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id) + resume_request = self.validator.validate_and_convert_request(resume_dict) + await append_message( + response_id, + next_seq, + stream_event={ + "type": "response.resumed", + "attempt": new_attempt, + "from_seq": next_seq, + }, + attempt_number=new_attempt, + ) + + 
logger.info( + "[durable] claim succeeded response_id=%s new_attempt=%d pod=%s resume_from_seq=%d", + response_id, + new_attempt, + _POD_ID, + next_seq, + ) + + task = asyncio.create_task( + self._run_background_stream( + response_id, + resume_request, + return_trace_id=False, + attempt_number=new_attempt, + ), + name=f"resume-{response_id}-{new_attempt}", + ) + self._track_task(response_id, task) + return new_attempt + async def _handle_retrieve_request( self, response_id: str, @@ -523,7 +1136,20 @@ async def _handle_retrieve_request( if resp is None: raise HTTPException(status_code=404, detail="Response not found") - _, status, created_at, trace_id = resp + # Try a lazy resume before falling back to the absolute-timeout sweep. + # This gives us crash-recovery semantics: an idle client reconnecting + # after a pod died will reclaim the run and resume it here instead of + # just marking it failed. + await self._try_claim_and_resume(response_id, resp) + + # Refresh after the potential resume: status / attempt_number may have changed. + resp = await get_response(response_id) + if resp is None: + raise HTTPException(status_code=404, detail="Response not found") + + status = resp.status + created_at = resp.created_at + trace_id = resp.trace_id if ( status == "in_progress" @@ -542,10 +1168,9 @@ async def _handle_retrieve_request( }, ) # TODO: sequence number computation here is racy under concurrent writers. - # Acceptable at current scale; for high-QPS use a DB-assigned sequence or - # SELECT FOR UPDATE on the response row to serialise writers. 
existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -558,6 +1183,7 @@ async def _handle_retrieve_request( "code": "task_timeout", }, }, + attempt_number=attempt, ) status = "failed" @@ -579,25 +1205,45 @@ async def _handle_retrieve_request( messages = await get_messages(response_id, after_sequence=None) if not messages and status == "in_progress": - return {"id": response_id, "status": "in_progress"} + return { + "id": response_id, + "status": "in_progress", + "attempt_number": resp.attempt_number, + } if status == "completed" and messages: + # Only consider items from the final (successful) attempt so that + # abandoned in-progress items from crashed attempts don't leak + # into the authoritative response body. Completed output_item.done + # events across attempts together make up the conversation — the + # agent SDK's checkpointer guarantees done-items are not re-emitted + # by later attempts, so this is a union with no duplicates. 
output = [] - for _, _, evt in messages: - if evt and "item" in evt: - output.append(evt["item"]) + for _, _, evt, _attempt in messages: + if evt and evt.get("type") == "response.output_item.done": + output.append(evt.get("item")) result: dict[str, Any] = { "id": response_id, "status": "completed", - "output": output, + "output": [o for o in output if o is not None], + "attempt_number": resp.attempt_number, } if trace_id: result["metadata"] = {"trace_id": trace_id} return result if status == "failed" and messages: - for _, _, evt in messages: + for _, _, evt, _attempt in messages: if evt and evt.get("type") == "error": - return {"id": response_id, "status": "failed", "error": evt.get("error")} - return {"id": response_id, "status": status} + return { + "id": response_id, + "status": "failed", + "error": evt.get("error"), + "attempt_number": resp.attempt_number, + } + return { + "id": response_id, + "status": status, + "attempt_number": resp.attempt_number, + } async def _stream_retrieve( self, @@ -638,15 +1284,26 @@ async def _stream_retrieve( ) break - _, status, _, _ = resp + status = resp.status + # Self-heal: if this response is still in_progress but its owning + # pod has gone silent past heartbeat_stale_threshold, try to claim + # + resume on this pod. A no-op if heartbeat is fresh or another + # pod already won. Without this, a stream opened before the crash + # would idle forever polling a dead run — since _try_claim_and_resume + # is only triggered by the outer retrieve handler on fresh GETs. + if status == "in_progress": + await self._try_claim_and_resume(response_id, resp) + # starting_after=0 fetches all messages (sequence numbers start at 0). # We use after_sequence=-1 for the DB query so that seq 0 is included. 
after_seq = last_seq - 1 if last_seq == 0 else last_seq messages = await get_messages(response_id, after_sequence=after_seq) - for seq, _, evt in messages: + for seq, _, evt, _attempt in messages: if evt is not None: - evt = {**evt, "sequence_number": seq} + # Tag every SSE frame with the response_id so proxies / + # clients can discover it without parsing nested fields. + evt = {**evt, "sequence_number": seq, "response_id": response_id} event_type = evt.get("type", "message") logger.debug( "SSE event", diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 7b646116..4e41b715 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -15,6 +15,12 @@ class LongRunningSettings: poll_interval_seconds: float = 1.0 db_statement_timeout_ms: int = 5000 cleanup_timeout_seconds: float = 7.0 + heartbeat_interval_seconds: float = 3.0 + heartbeat_stale_threshold_seconds: float = 10.0 + # Walk request.input[] on every request and drop/repair orphaned + # function_call / function_call_output pairs before the handler runs. + # Lets handlers stay framework-idiomatic without carrying repair logic. 
+ auto_sanitize_input: bool = True def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: @@ -25,6 +31,16 @@ def __post_init__(self) -> None: raise ValueError("db_statement_timeout_ms must be positive") if self.cleanup_timeout_seconds <= 0: raise ValueError("cleanup_timeout_seconds must be positive") + if self.heartbeat_interval_seconds <= 0: + raise ValueError("heartbeat_interval_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= 0: + raise ValueError("heartbeat_stale_threshold_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= self.heartbeat_interval_seconds: + raise ValueError( + f"heartbeat_stale_threshold_seconds ({self.heartbeat_stale_threshold_seconds}) " + f"must be strictly greater than heartbeat_interval_seconds " + f"({self.heartbeat_interval_seconds}) to avoid false stale-run detection." + ) db_timeout_s = self.db_statement_timeout_ms / 1000.0 if self.cleanup_timeout_seconds <= db_timeout_s: raise ValueError( diff --git a/src/databricks_ai_bridge/tool_repair.py b/src/databricks_ai_bridge/tool_repair.py new file mode 100644 index 00000000..99a7e43c --- /dev/null +++ b/src/databricks_ai_bridge/tool_repair.py @@ -0,0 +1,149 @@ +"""Shared orphan-tool-call repair logic. + +``sanitize_tool_items`` walks a list of Responses-API-style items and +reconciles orphan / duplicate ``function_call`` / ``function_call_output`` +items. Used by: + +* the server-side input sanitizer in :mod:`...long_running.server`, which + runs on every request before the handler is invoked; and +* the OpenAI :class:`AsyncDatabricksSession` ``get_items`` auto-repair, + which returns protocol-valid items without touching the underlying DB. + +The LangChain checkpointer has its own repair path +(``_build_tool_resume_repair``) that operates on ``AIMessage`` / +``ToolMessage`` shapes rather than the dict items here. 
logger = logging.getLogger(__name__)

#: Default body for the synthetic ``function_call_output`` injected when a
#: prior attempt's tool call has no matching output (e.g. the pod was killed
#: between emitting the call and its result). Shared between the server-side
#: input sanitizer and integration-side read-time repair paths so the user-
#: visible text stays consistent across the durable-resume contract.
DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT = (
    "[INTERRUPTED] This tool call did not complete due to a server "
    "interruption, so no result is available. Other tool calls in the "
    "conversation history completed normally and their results remain valid. "
    "If the information is still needed, re-invoking only this specific tool "
    "is usually sufficient."
)


def _default_item_get(item: Any, key: str) -> Any:
    """Read ``key`` from a dict item, or as an attribute from any other object.

    Returns ``None`` when the key / attribute is absent.
    """
    if isinstance(item, dict):
        return item.get(key)
    return getattr(item, key, None)


def _usable_call_id(value: Any) -> Any:
    """Return ``value`` if it can serve as a call id, else ``None``.

    A usable id is truthy and hashable. Ids are tracked in sets, so an
    unhashable value (e.g. a list sent by a malformed client request) would
    otherwise raise ``TypeError`` in the middle of sanitizing.
    """
    if not value:
        return None
    try:
        hash(value)
    except TypeError:
        return None
    return value


def _tool_item_key(item: Any, item_get: Callable[[Any, str], Any]) -> tuple:
    """Return ``(type, call_id)`` for ``item``.

    ``call_id`` is ``None`` unless the item is a ``function_call`` or
    ``function_call_output`` carrying a usable id.
    """
    kind = item_get(item, "type")
    raw_id = item_get(item, "call_id")
    if kind == "function_call" or kind == "function_call_output":
        return kind, _usable_call_id(raw_id)
    return kind, None


def _chat_tool_call_ids(tool_calls: Any) -> list:
    """Collect usable ids from a chat-completions ``tool_calls`` list.

    Each dict entry contributes its ``id``, falling back to ``function.id``.
    Non-list input, non-dict entries and non-dict ``function`` values are
    ignored rather than raising.
    """
    found: list = []
    if not isinstance(tool_calls, list):
        return found
    for tc in tool_calls:
        if not isinstance(tc, dict):
            continue
        tc_id = _usable_call_id(tc.get("id"))
        if tc_id is None:
            fn = tc.get("function")
            if isinstance(fn, dict):
                tc_id = _usable_call_id(fn.get("id"))
        if tc_id is not None:
            found.append(tc_id)
    return found


def sanitize_tool_items(
    items: list[Any],
    synthetic_output: str = DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT,
    *,
    item_get: Callable[[Any, str], Any] = _default_item_get,
    log_prefix: str = "[durable] items sanitized",
) -> list[Any]:
    """Return a protocol-valid view of ``items``.

    In order:

    * drops duplicate ``function_call`` items by ``call_id``;
    * drops duplicate or orphan ``function_call_output`` items (no matching
      ``function_call`` anywhere in the list);
    * injects a synthetic ``function_call_output`` immediately after any
      ``function_call`` that has no output in the list.

    Also recognises chat-completions-shape ``{role: assistant, tool_calls:
    [...]}`` items as declaring call_ids, so mixed-shape histories don't
    trip the orphan check.

    Returns the caller's ``items`` reference unchanged on the happy path so
    downstream can skip any re-persistence cheaply.

    Malformed items are never a reason to raise: a ``function_call`` /
    ``function_call_output`` whose ``call_id`` is missing or unhashable is
    passed through untouched, leaving rejection to whatever validates the
    items afterwards.

    Args:
        items: Responses-API-style items (dicts, or objects readable via
            ``item_get``).
        synthetic_output: Text of the injected ``function_call_output``.
            Defaults to ``DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT`` so all
            repair paths show the same wording; pass another string to
            override it for one caller.
        item_get: Accessor ``(item, key) -> value``. The default reads dict
            keys and falls back to attribute access, which lets
            session-style objects reuse this walker.
        log_prefix: Prefix of the info log line emitted when anything was
            injected or dropped.

    Returns:
        ``items`` itself when nothing changed, otherwise a new list.
    """
    if not items:
        return items

    # Pass 1: learn which call ids are declared and which already have output.
    declared_call_ids: set = set()
    call_ids_with_output: set = set()
    for item in items:
        kind, call_id = _tool_item_key(item, item_get)
        if call_id is not None:
            if kind == "function_call":
                declared_call_ids.add(call_id)
            else:
                call_ids_with_output.add(call_id)
        # Chat-completions shape: assistant message with tool_calls.
        if item_get(item, "role") == "assistant":
            declared_call_ids.update(_chat_tool_call_ids(item_get(item, "tool_calls")))

    # Pass 2: rebuild the list, dropping duplicates/orphans and closing
    # unanswered calls right where they occur.
    sanitized: list[Any] = []
    seen_calls: set = set()
    seen_outputs: set = set()
    injected = 0
    dropped_orphan_outputs = 0
    dropped_duplicates = 0

    for item in items:
        kind, call_id = _tool_item_key(item, item_get)
        if call_id is None:
            # Not a tool item, or one without a usable id: keep as-is.
            sanitized.append(item)
        elif kind == "function_call":
            if call_id in seen_calls:
                dropped_duplicates += 1
                continue
            seen_calls.add(call_id)
            sanitized.append(item)
            if call_id not in call_ids_with_output:
                sanitized.append(
                    {
                        "type": "function_call_output",
                        "call_id": call_id,
                        "output": synthetic_output,
                    }
                )
                injected += 1
        else:  # function_call_output
            if call_id in seen_outputs:
                dropped_duplicates += 1
                continue
            if call_id not in declared_call_ids:
                dropped_orphan_outputs += 1
                continue
            seen_outputs.add(call_id)
            sanitized.append(item)

    if not (injected or dropped_orphan_outputs or dropped_duplicates):
        # Happy path: hand back the original list so callers can skip
        # re-persistence by identity comparison (``sanitized is items``).
        return items

    logger.info(
        "%s: injected=%d dropped_orphan_outputs=%d dropped_duplicates=%d original=%d final=%d",
        log_prefix,
        injected,
        dropped_orphan_outputs,
        dropped_duplicates,
        len(items),
        len(sanitized),
    )
    return sanitized
await init_db(autoscaling_endpoint="ep") - mock_conn.execute.assert_awaited_once() - sql_arg = str(mock_conn.execute.call_args[0][0]) - assert "CREATE SCHEMA IF NOT EXISTS" in sql_arg + # init_db runs: CREATE SCHEMA + run_sync(create_all) + a series of + # ADD COLUMN IF NOT EXISTS / CREATE INDEX IF NOT EXISTS to migrate + # the durability columns onto pre-existing tables. + all_sql = " | ".join(str(call.args[0]) for call in mock_conn.execute.call_args_list) + assert "CREATE SCHEMA IF NOT EXISTS" in all_sql + assert "ADD COLUMN IF NOT EXISTS owner_pod_id" in all_sql + assert "ADD COLUMN IF NOT EXISTS heartbeat_at" in all_sql + assert "ADD COLUMN IF NOT EXISTS attempt_number" in all_sql + assert "ADD COLUMN IF NOT EXISTS original_request" in all_sql + assert "idx_responses_stale" in all_sql mock_conn.run_sync.assert_awaited_once() @pytest.mark.asyncio @@ -346,3 +363,110 @@ async def fake_factory(): monkeypatch.setattr(db_mod, "_session_factory", fake_factory) async with session_scope() as session: assert session is mock_session + + +# --------------------------------------------------------------------------- +# Durability metadata: owner_pod_id, heartbeat, claim, attempt_number +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_response_with_owner_and_original_request(mock_session): + """New background callers stamp pod id + serialized request on creation — + without these, a resumed pod can't re-invoke the handler.""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response( + "resp_abc", + "in_progress", + owner_pod_id="pod-1", + original_request={"input": [{"role": "user", "content": "hi"}]}, + ) + added = mock_session.add.call_args[0][0] + assert added.owner_pod_id == "pod-1" + assert added.heartbeat_at is not None + # original_request is JSON-encoded for Text storage. 
@pytest.mark.asyncio
async def test_create_response_without_durability_metadata(mock_session):
    """A call without durability arguments must leave owner, heartbeat and
    original_request unset, so the stale sweep cannot claim such rows."""
    from databricks_ai_bridge.long_running.repository import create_response

    await create_response("resp_x", "in_progress")

    stored_row = mock_session.add.call_args.args[0]
    assert stored_row.owner_pod_id is None
    assert stored_row.heartbeat_at is None
    assert stored_row.original_request is None


@pytest.mark.asyncio
async def test_heartbeat_response_updates_timestamp(mock_session):
    """One affected row means the heartbeat succeeded and was committed."""
    from databricks_ai_bridge.long_running.repository import heartbeat_response

    mock_session.execute.return_value = MagicMock(rowcount=1)

    succeeded = await heartbeat_response("resp_abc", "pod-1")

    assert succeeded is True
    mock_session.commit.assert_awaited_once()


@pytest.mark.asyncio
async def test_heartbeat_response_fails_when_not_owner(mock_session):
    """Zero affected rows (owner changed / row deleted) is reported as a
    failure so the caller can stop its heartbeat loop."""
    from databricks_ai_bridge.long_running.repository import heartbeat_response

    mock_session.execute.return_value = MagicMock(rowcount=0)

    succeeded = await heartbeat_response("resp_abc", "pod-1")

    assert succeeded is False


@pytest.mark.asyncio
async def test_claim_stale_response_returns_attempt_number(mock_session):
    """A returned row yields the attempt number it carries."""
    from databricks_ai_bridge.long_running.repository import claim_stale_response

    # Row stand-in that yields 2 both by iteration and by index.
    claimed_row = MagicMock()
    claimed_row.__iter__ = lambda self: iter([2])
    claimed_row.__getitem__ = lambda self, i: 2
    query_result = MagicMock()
    query_result.first.return_value = claimed_row
    mock_session.execute.return_value = query_result

    new_attempt = await claim_stale_response(
        "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0
    )

    assert new_attempt == 2


@pytest.mark.asyncio
async def test_claim_stale_response_returns_none_when_not_eligible(mock_session):
    """No returned row means nothing was claimed."""
    from databricks_ai_bridge.long_running.repository import claim_stale_response

    query_result = MagicMock()
    query_result.first.return_value = None
    mock_session.execute.return_value = query_result

    new_attempt = await claim_stale_response(
        "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0
    )

    assert new_attempt is None


@pytest.mark.asyncio
async def test_append_message_with_attempt_number(mock_session):
    """Appended events carry the attempt number they were written under, so
    retrieve can filter on it or the client can render the response.resumed
    boundary cleanly."""
    from databricks_ai_bridge.long_running.repository import append_message

    await append_message("resp_abc", 5, stream_event={"x": 1}, attempt_number=3)

    stored_row = mock_session.add.call_args.args[0]
    assert stored_row.attempt_number == 3
    assert stored_row.sequence_number == 5
+ response_id: str = "resp_123", + status: str = "in_progress", + created_at=None, + trace_id: str | None = None, + owner_pod_id: str | None = None, + heartbeat_at=None, + attempt_number: int = 1, + original_request: dict | None = None, +) -> ResponseInfo: + """Build a ResponseInfo with sensible defaults for tests. + + Mirrors the server's repository model so test setups stay terse even as + durability columns grow over time. + """ + if created_at is None: + created_at = datetime.now(timezone.utc) + return ResponseInfo( + response_id=response_id, + status=status, + created_at=created_at, + trace_id=trace_id, + owner_pod_id=owner_pod_id, + heartbeat_at=heartbeat_at, + attempt_number=attempt_number, + original_request=original_request, + ) + + +def _msg(seq: int, item=None, evt=None, attempt: int = 1): + """Build a (seq, item, stream_event, attempt_number) tuple for get_messages mocks.""" + return (seq, item, evt, attempt) + + def _mock_span(): """Return a mock MLflow span with the attributes the server uses.""" span = MagicMock() @@ -55,8 +95,8 @@ def _mock_validator(server): class TestSSEEvent: def test_dict_data(self): result = _sse_event("response.created", {"id": "resp_123", "status": "in_progress"}) - assert result.startswith("event: response.created\n") - assert "data: " in result + assert result.startswith("data: ") + assert "event:" not in result assert result.endswith("\n\n") data_line = result.split("data: ")[1].strip() parsed = json.loads(data_line) @@ -64,8 +104,8 @@ def test_dict_data(self): def test_string_data(self): result = _sse_event("error", "something went wrong") - assert "event: error\n" in result - assert "data: something went wrong\n\n" in result + assert "event:" not in result + assert result == "data: something went wrong\n\n" class TestLongRunningSettings: @@ -189,7 +229,7 @@ def test_starting_after_zero_without_stream_is_allowed(self): patch( f"{MODULE}.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", 
datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( f"{MODULE}.get_messages", @@ -209,8 +249,13 @@ async def test_marks_response_failed(self): patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, - return_value=[(0, None, {"type": "response.created"})], + return_value=[_msg(0, None, {"type": "response.created"})], ) as mock_get, + patch( + "databricks_ai_bridge.long_running.server.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch( "databricks_ai_bridge.long_running.server.append_message", new_callable=AsyncMock, @@ -271,13 +316,13 @@ async def test_completed_returns_output(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), "trace_abc"), + return_value=_resp_info("resp_123", "completed", trace_id="trace_abc"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - ( + _msg( 0, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -305,7 +350,7 @@ async def test_stale_run_detection(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_stale", "in_progress", old_time, None), + return_value=_resp_info("resp_stale", "in_progress", created_at=old_time), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -336,7 +381,7 @@ async def test_in_progress_returns_status(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -347,7 +392,11 @@ async def test_in_progress_returns_status(self): result = await server._handle_retrieve_request( "resp_123", 
stream=False, starting_after=0 ) - assert result == {"id": "resp_123", "status": "in_progress"} + assert result == { + "id": "resp_123", + "status": "in_progress", + "attempt_number": 1, + } class TestStreamRetrieve: @@ -360,14 +409,14 @@ async def test_completed_stream(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "completed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created", "id": "resp_123"}), - ( + _msg(0, None, {"type": "response.created", "id": "resp_123"}), + _msg( 1, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -394,13 +443,13 @@ async def test_failed_stream_stops(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "failed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "failed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "error", "error": {"message": "boom"}}), + _msg(0, None, {"type": "error", "error": {"message": "boom"}}), ], ), ): @@ -668,10 +717,15 @@ async def test_exception_writes_error_event_inline(self): f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created"}), - (1, None, {"type": "response.output_text.delta"}), + _msg(0, None, {"type": "response.created"}), + _msg(1, None, {"type": "response.output_text.delta"}), ], ), + patch( + f"{MODULE}.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, ): @@ -875,3 +929,801 @@ async def 
class TestSanitizeRequestInput:
    """Full-history orphan walker — catches mid-history orphans that neither
    the LangGraph middleware (trailing-only) nor session.repair() (session-only)
    cover. See rotation-findings.md Test E."""

    @staticmethod
    def _call(call_id):
        """Build a ``function_call`` item for ``call_id``."""
        return {"type": "function_call", "call_id": call_id, "name": "f", "arguments": "{}"}

    @staticmethod
    def _output(call_id, output="ok"):
        """Build a ``function_call_output`` item for ``call_id``."""
        return {"type": "function_call_output", "call_id": call_id, "output": output}

    def test_empty_input_is_noop(self):
        assert _sanitize_request_input({}) == {}
        assert _sanitize_request_input({"input": []}) == {"input": []}

    def test_passes_through_paired_call_and_output(self):
        history = [
            {"role": "user", "content": "hi"},
            self._call("c1"),
            self._output("c1"),
            {"role": "assistant", "content": "done"},
        ]
        result = _sanitize_request_input({"input": list(history)})
        assert result["input"] == history

    def test_injects_synthetic_output_for_trailing_orphan_call(self):
        history = [{"role": "user", "content": "hi"}, self._call("c1")]
        sanitized = _sanitize_request_input({"input": history})["input"]
        assert len(sanitized) == 3
        assert sanitized[2] == self._output("c1", DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT)

    def test_injects_synthetic_output_for_midhistory_orphan_call(self):
        # The case that today's middleware misses (Test E).
        history = [
            {"role": "user", "content": "hi"},
            self._call("c1"),
            {"role": "user", "content": "different question"},
        ]
        sanitized = _sanitize_request_input({"input": history})["input"]
        assert len(sanitized) == 4
        assert sanitized[1]["type"] == "function_call"
        assert sanitized[2] == self._output("c1", DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT)
        assert sanitized[3] == {"role": "user", "content": "different question"}

    def test_drops_orphan_output_with_no_matching_call(self):
        # The LangGraph stream-event attempt-boundary artifact.
        first_turn = {"role": "user", "content": "hi"}
        second_turn = {"role": "user", "content": "follow-up"}
        history = [dict(first_turn), self._output("c-ghost", "x"), dict(second_turn)]
        result = _sanitize_request_input({"input": history})
        assert result["input"] == [first_turn, second_turn]

    def test_dedupes_duplicate_calls_and_outputs(self):
        history = [
            self._call("c1"),
            self._call("c1"),
            self._output("c1"),
            self._output("c1"),
        ]
        sanitized = _sanitize_request_input({"input": history})["input"]
        assert len(sanitized) == 2
        assert sanitized[0]["type"] == "function_call"
        assert sanitized[1]["type"] == "function_call_output"

    def test_recognizes_chat_completions_shape_as_declaring_call_id(self):
        # An assistant message with tool_calls counts as "declaring" a call_id,
        # so a matching function_call_output further down is NOT dropped.
        assistant_turn = {
            "role": "assistant",
            "content": [],
            "tool_calls": [
                {"id": "tc-1", "type": "function", "function": {"name": "f", "arguments": "{}"}}
            ],
        }
        history = [assistant_turn, self._output("tc-1")]
        result = _sanitize_request_input({"input": list(history)})
        assert result["input"] == history

    def test_non_dict_items_pass_through(self):
        history = [{"role": "user", "content": "hi"}, "not-a-dict", 42]
        result = _sanitize_request_input({"input": list(history)})
        assert result["input"] == history
+ self._event(2, 2, "function_call", "c2"), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert [i["call_id"] for i in out] == ["c1", "c1"] + assert [i["type"] for i in out] == ["function_call", "function_call_output"] + + def test_only_output_item_done_events_count(self): + noise = ( + 0, + None, + {"type": "response.output_text.delta", "delta": "hi"}, + 1, + ) + messages = [noise, self._event(1, 1, "function_call", "c1")] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert len(out) == 1 + assert out[0]["call_id"] == "c1" + + def test_messages_hoisted_after_tool_pairs(self): + # Claude interleaves narrative `message` items between function_call + # and function_call_output in its event stream. Preserving that + # ordering in the replay would violate Anthropic's "tool_use + # immediately followed by tool_result" rule. Collector hoists all + # narrative messages to the end so tool pairs stay adjacent. + messages = [ + self._event(0, 1, "function_call", "c1"), + ( + 1, + None, + { + "type": "response.output_item.done", + "item": {"type": "message", "role": "assistant", "content": "step one"}, + }, + 1, + ), + self._event(2, 1, "function_call_output", "c1", output="ok"), + self._event(3, 1, "function_call", "c2"), + ( + 4, + None, + { + "type": "response.output_item.done", + "item": {"type": "message", "role": "assistant", "content": "step two"}, + }, + 1, + ), + self._event(5, 1, "function_call_output", "c2", output="ok"), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + # 2 pairs + 2 messages. + assert len(out) == 6 + assert [i["type"] for i in out] == [ + "function_call", + "function_call_output", + "function_call", + "function_call_output", + "message", + "message", + ] + # call_ids paired up (c1,c1,c2,c2) and narrative in event order. 
    def test_reassembles_partial_text_from_delta_events(self):
        # Attempt crashed mid-stream: item.added + deltas but no item.done.
        # The collector should synthesize a message item from accumulated deltas
        # so the next attempt sees where narration trailed off.
        messages = [
            (
                0,
                None,
                {
                    "type": "response.output_item.added",
                    "item": {"type": "message", "id": "msg_1"},
                },
                1,
            ),
            (
                1,
                None,
                {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "Hello, "},
                1,
            ),
            (
                2,
                None,
                {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "world"},
                1,
            ),
            # No item.done — crash.
        ]
        out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1)
        # Exactly one synthesized assistant message whose text is the deltas
        # concatenated in order.
        assert len(out) == 1
        assert out[0]["type"] == "message"
        assert out[0]["role"] == "assistant"
        assert out[0]["content"][0]["text"] == "Hello, world"

    def test_ignores_partial_text_if_item_eventually_completed(self):
        # Deltas streamed, then item.done landed for the same item id. The
        # completed message must be inherited exactly once, with the text
        # from the .done event ("Hello, world"), and the earlier delta
        # ("Hello") must not be reassembled into a second message.
        messages = [
            (
                0,
                None,
                {
                    "type": "response.output_item.added",
                    "item": {"type": "message", "id": "msg_1"},
                },
                1,
            ),
            (
                1,
                None,
                {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "Hello"},
                1,
            ),
            (
                2,
                None,
                {
                    "type": "response.output_item.done",
                    "item": {
                        "type": "message",
                        "id": "msg_1",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": "Hello, world"}],
                    },
                },
                1,
            ),
        ]
        out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1)
        # Completed message inherits via the narrative bucket; partial
        # reassembly cleared its tracker when .done arrived so it does NOT
        # also synthesize a duplicate from the deltas.
        assert len(out) == 1
        assert out[0]["type"] == "message"
        assert out[0]["content"][0]["text"] == "Hello, world"
class TestRotateConversationId:
    """``_rotate_conversation_id`` derives ``<anchor>::attempt-<n>`` and drops
    the client-pinned thread/session id from ``custom_inputs``."""

    def test_rotate_drops_thread_id_and_sets_rotated_context(self):
        request = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}}
        rotated = _rotate_conversation_id(request, new_attempt_number=2, response_id="resp_x")
        custom_inputs = rotated["custom_inputs"]
        assert "thread_id" not in custom_inputs
        assert custom_inputs["user_id"] == "u"
        assert rotated["context"]["conversation_id"] == "t1::attempt-2"

    def test_rotate_drops_session_id(self):
        request = {"custom_inputs": {"session_id": "s1"}, "context": {}}
        rotated = _rotate_conversation_id(request, new_attempt_number=2, response_id="resp_x")
        assert "session_id" not in rotated["custom_inputs"]
        assert rotated["context"]["conversation_id"] == "s1::attempt-2"

    def test_rotate_falls_back_to_context_conversation_id(self):
        request = {"custom_inputs": {}, "context": {"conversation_id": "c-abc"}}
        rotated = _rotate_conversation_id(request, new_attempt_number=3, response_id="resp_x")
        assert rotated["context"]["conversation_id"] == "c-abc::attempt-3"

    def test_rotate_falls_back_to_response_id_as_last_resort(self):
        request = {"custom_inputs": {}, "context": {}}
        rotated = _rotate_conversation_id(request, new_attempt_number=2, response_id="resp_x")
        assert rotated["context"]["conversation_id"] == "resp_x::attempt-2"

    def test_rotate_handles_missing_custom_inputs_key(self):
        request = {"context": {"conversation_id": "c-abc"}}
        rotated = _rotate_conversation_id(request, new_attempt_number=2, response_id="resp_x")
        assert rotated["context"]["conversation_id"] == "c-abc::attempt-2"
        assert rotated["custom_inputs"] == {}


class TestInjectConversationId:
    """Anchoring an otherwise-anonymous request to a response_id guarantees a
    resumed run on a new pod resolves to the same agent-SDK thread/session."""

    def test_injects_when_nothing_set(self):
        request = {"input": [], "custom_inputs": {}, "context": {}}
        anchored = _inject_conversation_id(request, "resp_abc")
        assert anchored["context"]["conversation_id"] == "resp_abc"

    def test_respects_existing_conversation_id(self):
        request = {"input": [], "context": {"conversation_id": "user-set"}}
        anchored = _inject_conversation_id(request, "resp_abc")
        assert anchored["context"]["conversation_id"] == "user-set"

    def test_respects_thread_id_from_custom_inputs(self):
        request = {"input": [], "custom_inputs": {"thread_id": "t-1"}, "context": {}}
        anchored = _inject_conversation_id(request, "resp_abc")
        # When the client already pinned a thread, we don't overwrite — the
        # template's _get_or_create_thread_id picks up custom_inputs first.
        context = anchored["context"] or {}
        assert "conversation_id" not in context

    def test_respects_session_id_from_custom_inputs(self):
        request = {"input": [], "custom_inputs": {"session_id": "s-1"}, "context": {}}
        anchored = _inject_conversation_id(request, "resp_abc")
        context = anchored["context"] or {}
        assert "conversation_id" not in context

    def test_handles_missing_context_key(self):
        request = {"input": [], "custom_inputs": {}}
        anchored = _inject_conversation_id(request, "resp_abc")
        assert anchored["context"]["conversation_id"] == "resp_abc"

    def test_does_not_mutate_input(self):
        request = {"input": [], "custom_inputs": {}, "context": {}}
        _inject_conversation_id(request, "resp_abc")
        assert request["context"] == {}  # original untouched
captured["response_id"] = response_id + captured["status"] = status + captured["owner_pod_id"] = owner_pod_id + captured["original_request"] = original_request + + with ( + patch(f"{MODULE}.create_response", side_effect=fake_create_response), + patch("asyncio.create_task") as mock_create_task, + ): + result = await server._handle_background_request( + {"input": [{"role": "user", "content": "hi"}]}, + is_streaming=False, + return_trace_id=False, + ) + + assert captured["status"] == "in_progress" + assert captured["owner_pod_id"] # non-empty + # original_request should include input + injected conversation_id. + orig = captured["original_request"] + assert orig["input"] == [{"role": "user", "content": "hi"}] + assert orig["context"]["conversation_id"] == captured["response_id"] + # Return shape: immediate response_obj, not a stream. + assert result["id"] == captured["response_id"] + assert result["status"] == "in_progress" + mock_create_task.assert_called_once() + + +class TestTryClaimAndResume: + @pytest.mark.asyncio + async def test_no_op_when_completed(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + resp = _resp_info(status="completed") + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_grace_period_for_fresh_run(self): + """Just-started runs get a grace window before they're claim-eligible.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_stale_threshold_seconds=15.0 + ) + # created 2s ago, no heartbeat yet → should NOT be claimed. 
+ from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=2), + heartbeat_at=None, + original_request={"input": []}, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_op_without_original_request(self): + """Legacy rows created before durability metadata can't be resumed.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=None, + original_request=None, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_claim_fails_returns_none(self): + """Another pod won the race — we quietly step aside.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=300), + original_request={"input": [{"role": "user"}]}, + ) + with ( + patch( + f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=None + ) as mock_claim, + patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, + ): + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_awaited_once() + mock_append.assert_not_awaited() + + @pytest.mark.asyncio + async def 
test_successful_claim_spawns_resume_and_emits_sentinel(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"user_id": "u"}, + "context": {"conversation_id": "resp_x"}, + }, + ) + captured: dict = {} + + async def fake_append(response_id, seq, *, item=None, stream_event=None, attempt_number=1): + captured["seq"] = seq + captured["event"] = stream_event + captured["attempt_tag"] = attempt_number + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch( + f"{MODULE}.get_messages", + new_callable=AsyncMock, + return_value=[_msg(0, None, {}), _msg(1, None, {})], + ), + patch(f"{MODULE}.append_message", side_effect=fake_append), + patch("asyncio.create_task") as mock_create_task, + ): + attempt = await server._try_claim_and_resume("resp_x", resp) + + assert attempt == 2 + # Sentinel is written at next_seq (existing seqs were 0 and 1). + assert captured["seq"] == 2 + assert captured["event"]["type"] == "response.resumed" + assert captured["event"]["attempt"] == 2 + assert captured["attempt_tag"] == 2 + # A resume task is spawned; it was not awaited synchronously. + mock_create_task.assert_called_once() + + @pytest.mark.asyncio + async def test_resume_replays_input_and_rotates_conversation_id(self): + """Resume must replay original_request.input (not blank it) and rotate + the conversation anchor so the handler resolves to a fresh thread / + session for the new attempt. 
Prevents the LangGraph stream-event + attempt-boundary orphan artifact (rotation-findings.md).""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"thread_id": "t1", "user_id": "u"}, + "context": {}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) + # Input is REPLAYED (not blanked) — the LLM sees full pre-crash history. + # The MLflow validator normalizes the shape (adds "type": "message" etc.) + # so compare essentials. + assert len(dumped["input"]) == 1 + assert dumped["input"][0]["role"] == "user" + assert dumped["input"][0]["content"] == "hi" + # thread_id was dropped so the handler's priority-2 fallback wins. 
+ assert "thread_id" not in (dumped["custom_inputs"] or {}) + # Other custom_inputs keys are preserved. + assert dumped["custom_inputs"]["user_id"] == "u" + # conversation_id is rotated to a per-attempt value anchored on t1. + assert dumped["context"]["conversation_id"] == "t1::attempt-2" + assert kwargs.get("attempt_number") == 2 + + @pytest.mark.asyncio + async def test_resume_rotation_anchors_on_context_conversation_id(self): + """When the client didn't pin a thread_id/session_id, rotation uses + the injected context.conversation_id as the base anchor.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {}, + "context": {"conversation_id": "resp_x"}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=3), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else 
resume_request + ) + # Rotation anchors on the stored context.conversation_id (priority 2). + # Note: re-rotating in a subsequent attempt would re-anchor on the + # ORIGINAL stored value, not the previous rotation — no stacking. + assert dumped["context"]["conversation_id"] == "resp_x::attempt-3" + + +class TestRetrieveTriggersLazyClaim: + @pytest.mark.asyncio + async def test_retrieve_calls_try_claim(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + + resp = _resp_info("resp_x", "in_progress") + with ( + patch(f"{MODULE}.get_response", new_callable=AsyncMock, return_value=resp), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch.object( + server, "_try_claim_and_resume", new_callable=AsyncMock, return_value=None + ) as mock_claim, + ): + await server._handle_retrieve_request("resp_x", stream=False, starting_after=0) + + mock_claim.assert_awaited_once() + + +class TestHeartbeatContextManager: + @pytest.mark.asyncio + async def test_writes_heartbeat_periodically(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x"): + await asyncio.sleep(0.2) # enough time for 2+ heartbeats + + # Heartbeat interval is 0.05s so we should see at least 2 writes. 
+ assert mock_hb.await_count >= 2 + for call in mock_hb.await_args_list: + assert call.args[0] == "resp_x" + + @pytest.mark.asyncio + async def test_stops_cleanly_on_exit(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x"): + pass # immediate exit + + # Give the heartbeat loop a chance to observe the stop signal. + await asyncio.sleep(0.1) + writes_after_exit = mock_hb.await_count + + await asyncio.sleep(0.15) + # No new writes after the scope closed. + assert mock_hb.await_count == writes_after_exit + + @pytest.mark.asyncio + async def test_db_error_does_not_interrupt_body(self): + """Heartbeat failures are logged, not raised — the stale check catches + real death, so a transient write miss must not kill a live run.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + body_ran = False + with patch( + f"{MODULE}.heartbeat_response", + new_callable=AsyncMock, + side_effect=RuntimeError("db down"), + ): + async with server._heartbeat("resp_x"): + await asyncio.sleep(0.1) + body_ran = True + assert body_ran + + +class TestSettingsHeartbeatValidation: + def test_stale_must_exceed_interval(self): + with pytest.raises(ValueError, match="heartbeat_stale_threshold_seconds"): + LongRunningSettings( + heartbeat_interval_seconds=5.0, + heartbeat_stale_threshold_seconds=5.0, + ) + + def test_interval_must_be_positive(self): + with pytest.raises(ValueError, match="heartbeat_interval_seconds must be positive"): + LongRunningSettings(heartbeat_interval_seconds=0) + + def test_defaults_match_chat_ux(self): + # 3s interval + 15s stale gives ~5 heartbeats 
before a pod is considered + # dead — snug enough to recover conversations within a user's + # "reconnecting..." patience window. + s = LongRunningSettings() + assert s.heartbeat_interval_seconds == 3.0 + assert s.heartbeat_stale_threshold_seconds == 10.0 + + +class TestDebugKillTask: + """The opt-in debug-kill endpoint lets integration tests simulate a crash + against a deployed pod without restarting the whole app. Off by default + because exposing task cancellation bypasses the normal cleanup path.""" + + def test_endpoint_absent_by_default(self): + from starlette.testclient import TestClient + + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + resp = client.post("/_debug/kill_task/resp_x") + assert resp.status_code == 404 # route not registered + + def test_endpoint_registered_when_env_set(self, monkeypatch): + from starlette.testclient import TestClient + + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + # No in-flight task for this response_id on this pod → 404, not 405. + resp = client.post("/_debug/kill_task/resp_missing") + assert resp.status_code == 404 + assert "No in-flight task" in resp.json()["detail"] + + @pytest.mark.asyncio + async def test_cancels_tracked_task(self, monkeypatch): + """Direct-call variant: skip the TestClient (which is sync and blocks + the loop) and call the handler logic through _running_tasks directly. + Covers the important behavior: cancelling a tracked task propagates + CancelledError and the tracking dict is cleared by the done-callback. 
+ """ + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + + cancel_event = asyncio.Event() + + async def long_running(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancel_event.set() + raise + + task = asyncio.create_task(long_running()) + server._track_task("resp_tracked", task) + + # Yield once so the new task can start waiting on sleep(60). + await asyncio.sleep(0) + assert "resp_tracked" in server._running_tasks + + task.cancel() + # Expect CancelledError from awaiting the task itself, and the cancel + # event set inside the except handler before the re-raise. + with pytest.raises(asyncio.CancelledError): + await task + assert cancel_event.is_set() + # done-callback (scheduled on loop) clears the registration after the + # task completes — give it one more tick. + await asyncio.sleep(0) + assert "resp_tracked" not in server._running_tasks