Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
afd2d25
LongRunningAgentServer: add durable resume via heartbeat + CAS claim
dhruv0811 Apr 16, 2026
466a859
Update test_creates_schema_and_tables for new ADD COLUMN migration
dhruv0811 Apr 16, 2026
e7cfedd
Fix AttributeError when injecting conversation_id into a pydantic req…
dhruv0811 Apr 16, 2026
19df055
Tolerate InsufficientPrivilege in ADD COLUMN migrations
dhruv0811 Apr 16, 2026
97a5dcb
Include attempt_number in retrieve response for observability
dhruv0811 Apr 16, 2026
d3adee7
Add opt-in debug-kill endpoint for testing crash-resume on deployed apps
dhruv0811 Apr 16, 2026
d7c33b7
Apply ruff format + fix ty diagnostic on request_data dump
dhruv0811 Apr 20, 2026
f9f8a73
Log Background response created at INFO so response_id is visible in …
dhruv0811 Apr 20, 2026
d5666b2
Tag every SSE frame in stream retrieve with top-level response_id
dhruv0811 Apr 20, 2026
f2ffb6e
Self-heal open streams: call _try_claim_and_resume from _stream_retrieve
dhruv0811 Apr 20, 2026
6ee9f6c
Tighten heartbeat_stale_threshold_seconds default from 15s to 10s
dhruv0811 Apr 20, 2026
c2383f2
Add [durable] INFO-level lifecycle logs across the resume path
dhruv0811 Apr 20, 2026
cbd2b0b
Add public durable-resume repair helpers for openai + langchain
dhruv0811 Apr 21, 2026
5d70dde
Add pre_model_hook factory; WARN on skipped durability migrations
dhruv0811 Apr 21, 2026
62df014
Rename pre_model_hook factory to middleware factory
dhruv0811 Apr 21, 2026
4af26f0
Remove unused Callable import in checkpoint.py
dhruv0811 Apr 21, 2026
bc573fc
LongRunning: rotate conv_id on resume + full-history input sanitizer
dhruv0811 Apr 22, 2026
91360a9
Checkpoint saver: read-time repair so middleware stays optional
dhruv0811 Apr 22, 2026
5da7dbd
Session: auto-repair on get_items so middleware-free templates are safe
dhruv0811 Apr 22, 2026
51f0a4c
Stamp custom_inputs.attempt_number on resume so handlers can see retries
dhruv0811 Apr 22, 2026
f3f8eb3
Synthetic-output text: informative, scoped, nudges against re-running…
dhruv0811 Apr 22, 2026
40d7e09
Resume inherits prior attempt's completed tool outputs
dhruv0811 Apr 22, 2026
77cd8a8
Resume inheritance: include completed assistant message items
dhruv0811 Apr 22, 2026
7fecacd
Resume inheritance: reassemble mid-stream partial text from deltas
dhruv0811 Apr 22, 2026
6ef968f
Resume inheritance: also reassemble reasoning + function_call arg str…
dhruv0811 Apr 22, 2026
2f26e26
docs: update server.py docstrings for rotate+replay+inherit resume be…
dhruv0811 Apr 23, 2026
68ce276
Strip PR to bare minimum essentials for final durable-resume contract
dhruv0811 Apr 23, 2026
c23d9b6
Make build_tool_resume_repair internal (rename to _build_tool_resume_…
dhruv0811 Apr 23, 2026
8bd0718
server: add asyncio.sleep(0) yield point in stream loop
dhruv0811 Apr 23, 2026
4d0756e
Drop event: line from durable SSE frames
dhruv0811 Apr 23, 2026
07ded9d
Inheritance: drop completed message items, preserve only tool pairs
dhruv0811 Apr 23, 2026
7bcb1f3
Inheritance: hoist narrative messages after tool pairs instead of dro…
dhruv0811 Apr 23, 2026
8881563
Stable state — durable execution verified end-to-end
dhruv0811 Apr 23, 2026
ef8f86a
Refactor: extract shared sanitize_tool_items helper
dhruv0811 Apr 23, 2026
933333c
Consolidate [INTERRUPTED] synthetic output + simplify session API
dhruv0811 Apr 23, 2026
3f4cbe4
Merge remote-tracking branch 'origin/main' into pr-416
dhruv0811 Apr 23, 2026
d42ceb2
Apply ruff format + drop unused import
dhruv0811 Apr 23, 2026
3c1ca10
Move tool_repair.py out of long_running/
dhruv0811 Apr 23, 2026
e3db589
Add AGENTS.md for LongRunningAgentServer design
dhruv0811 Apr 28, 2026
7b9ae32
Revert "Add AGENTS.md for LongRunningAgentServer design"
dhruv0811 Apr 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion integrations/langchain/src/databricks_langchain/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

from typing import Any
import copy
import logging
from typing import Any, Sequence

from databricks.sdk import WorkspaceClient
from databricks_ai_bridge.tool_repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT

logger = logging.getLogger(__name__)

try:
from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebaseClient, LakebasePool
Expand All @@ -16,6 +21,109 @@

_checkpoint_imports_available = False

try:
from langchain_core.messages import AIMessage, ToolMessage

_message_imports_available = True
except ImportError:
AIMessage = object # type: ignore
ToolMessage = object # type: ignore
_message_imports_available = False


def _build_tool_resume_repair(messages: Sequence[Any]) -> list[Any]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need these synthetic events?

from our discussion in person, i think it should be sufficient for us to serialize the entire message history from the crashed conversations into a single user message with a prompt to the LLM about "the agent task runnig this crashed, contniue from here"

(maybe can say recovering from crash in the text also to explain why there is duplicatd content like two tool calls in a row)

pros/cons list from talking to claude:

Image

imo just prose recovery is much cleaner and doesn't rely on agent authors to know to use these specific APIs

"""Build synthetic ``ToolMessage`` responses for orphan tool calls.

Internal helper used by ``_repair_loaded_checkpoint_tuple``. When a
LangGraph run is killed mid-tool, the checkpointer preserves the
trailing ``AIMessage.tool_calls`` but the paired ``ToolMessage``s
never land. Replaying that state to the LLM fails because the API
(Anthropic in particular) requires every ``tool_use`` to be
immediately followed by a matching ``tool_result``.

Walks the trailing assistant turn (the last contiguous block of
``AIMessage`` / ``ToolMessage``) and returns a synthetic
``ToolMessage`` for each ``tool_call`` id that lacks a matching
``ToolMessage.tool_call_id``. The caller appends these to the
``messages`` channel before the next model call.
"""
if not _message_imports_available or not messages:
return []

# Trailing assistant turn: walk backwards until we hit a non-assistant/
# non-tool message. That block is the "pending" turn whose tool_use ↔
# tool_result pairing we need to enforce.
trailing_start = len(messages)
for i in range(len(messages) - 1, -1, -1):
if isinstance(messages[i], (AIMessage, ToolMessage)):
trailing_start = i
else:
break

tool_call_ids: list[str] = []
answered: set[str] = set()
for msg in messages[trailing_start:]:
if isinstance(msg, AIMessage):
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
if tc_id and tc_id not in tool_call_ids:
tool_call_ids.append(tc_id)
elif isinstance(msg, ToolMessage):
tcid = getattr(msg, "tool_call_id", None)
if tcid:
answered.add(tcid)

orphans = [tc_id for tc_id in tool_call_ids if tc_id not in answered]
return [
ToolMessage(tool_call_id=tc_id, content=DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT)
for tc_id in orphans
]


def _repair_loaded_checkpoint_tuple(tup: Any) -> Any:
"""Return a copy of ``tup`` with orphan tool_calls in its ``messages``
channel closed by synthetic ``ToolMessage`` s.

Called on every ``(a)get_tuple`` to make the served checkpoint
protocol-valid (every ``tool_use`` paired with a ``tool_result``)
transparently. A kill between the ``model`` and ``tools`` nodes leaves
the trailing ``AIMessage.tool_calls`` unpaired; on the NEXT turn that
state would otherwise leak into the LLM and be rejected by the
provider's pairing check.

Idempotent — ``_build_tool_resume_repair`` is a no-op when state is
already clean. Cheap — the walk is O(trailing-turn).

Side effect: the synthetic ``ToolMessage`` s added here become part of
the state LangGraph writes on the NEXT node boundary, so the repair
self-heals the DB row over time rather than re-computing on every read.
"""
if tup is None or not _message_imports_available:
return tup

checkpoint = getattr(tup, "checkpoint", None)
if not isinstance(checkpoint, dict):
return tup
channel_values = checkpoint.get("channel_values")
if not isinstance(channel_values, dict):
return tup
messages = channel_values.get("messages")
if not isinstance(messages, list) or not messages:
return tup

repair = _build_tool_resume_repair(messages)
if not repair:
return tup

logger.info(
"[durable] checkpoint read-time repair: injected %d synthetic ToolMessage(s)",
len(repair),
)
new_checkpoint = copy.copy(checkpoint)
new_checkpoint["channel_values"] = dict(channel_values)
new_checkpoint["channel_values"]["messages"] = list(messages) + list(repair)
return tup._replace(checkpoint=new_checkpoint)


class CheckpointSaver(PostgresSaver):
"""
Expand Down Expand Up @@ -68,6 +176,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._lakebase.close()
return False

def get_tuple(self, config):
"""Return the checkpoint tuple, with trailing orphan tool_calls paired."""
return _repair_loaded_checkpoint_tuple(super().get_tuple(config))


class AsyncCheckpointSaver(AsyncPostgresSaver):
"""
Expand Down Expand Up @@ -122,3 +234,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit async context manager and close the connection pool."""
await self._lakebase.close()
return False

async def aget_tuple(self, config):
"""Return the checkpoint tuple, with trailing orphan tool_calls paired."""
return _repair_loaded_checkpoint_tuple(await super().aget_tuple(config))
79 changes: 79 additions & 0 deletions integrations/langchain/tests/unit_tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,82 @@ async def test_async_checkpoint_saver_branch_resource_path(monkeypatch):

assert "host=auto-db-host" in test_pool.conninfo
assert saver._lakebase._is_autoscaling is True


class TestReadTimeCheckpointRepair:
"""Read-time repair: aget_tuple / get_tuple returns a state where every
trailing ``AIMessage.tool_calls`` is paired with a ``ToolMessage``. Keeps
user-space free of middleware when the app is built on our savers."""

def _make_tuple(self, messages):
from collections import namedtuple

FakeTuple = namedtuple(
"CheckpointTuple",
["config", "checkpoint", "metadata", "parent_config", "pending_writes"],
)
return FakeTuple(
config={},
checkpoint={
"v": 1,
"id": "ckpt",
"channel_values": {"messages": list(messages)},
},
metadata={},
parent_config=None,
pending_writes=None,
)

def test_repairs_trailing_orphan_tool_call(self):
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple

tup = self._make_tuple(
[
HumanMessage("hi"),
AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]),
]
)
repaired = _repair_loaded_checkpoint_tuple(tup)
msgs = repaired.checkpoint["channel_values"]["messages"]
assert len(msgs) == 3
assert isinstance(msgs[-1], ToolMessage)
assert msgs[-1].tool_call_id == "c1"

def test_noop_when_state_is_clean(self):
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple

tup = self._make_tuple(
[
HumanMessage("hi"),
AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]),
ToolMessage(tool_call_id="c1", content="ok"),
AIMessage(content="done"),
]
)
repaired = _repair_loaded_checkpoint_tuple(tup)
# No repair added → tuple unchanged.
assert repaired is tup

def test_none_tuple_passes_through(self):
from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple

assert _repair_loaded_checkpoint_tuple(None) is None

def test_does_not_mutate_original_messages_list(self):
from langchain_core.messages import AIMessage, HumanMessage

from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple

original_messages = [
HumanMessage("hi"),
AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]),
]
tup = self._make_tuple(original_messages)
original_len = len(original_messages)
_repair_loaded_checkpoint_tuple(tup)
# Calling repair must NOT mutate the caller's original list.
assert len(original_messages) == original_len
22 changes: 22 additions & 0 deletions integrations/openai/src/databricks_openai/agents/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,27 @@ async def main():
DEFAULT_TOKEN_CACHE_DURATION_SECONDS,
AsyncLakebaseSQLAlchemy,
)
from databricks_ai_bridge.tool_repair import sanitize_tool_items

_session_imports_available = True
except ImportError:
SQLAlchemySession = object # type: ignore
DEFAULT_TOKEN_CACHE_DURATION_SECONDS = None # type: ignore
DEFAULT_POOL_RECYCLE_SECONDS = None # type: ignore
sanitize_tool_items = None # type: ignore
_session_imports_available = False

logger = logging.getLogger(__name__)


def _sanitize_items(items: list[Any]) -> list[Any]:
"""Session-scoped wrapper around :func:`sanitize_tool_items` that only
sets the log prefix. Kept as a one-liner so existing
``self._sanitize_items`` call sites stay stable.
"""
return sanitize_tool_items(items, log_prefix="[durable] session items sanitized")


class AsyncDatabricksSession(SQLAlchemySession):
"""
Async OpenAI Agents SDK Session implementation for Databricks Lakebase.
Expand Down Expand Up @@ -179,6 +189,18 @@ async def _ensure_tables(self) -> None:
await self._lakebase.create_schema()
await super()._ensure_tables()

async def get_items(self, limit: Optional[int] = None) -> list[Any]:
"""Return session items, always repaired for protocol validity.

The returned list has every ``function_call`` paired with a
``function_call_output`` — orphans from a durable-resume crash get
a synthetic output appended, and duplicates get deduped. The
underlying DB rows are not modified; this is a pure in-memory
filter, cheap to re-run on every call.
"""
items = await super().get_items(limit=limit)
return _sanitize_items(items)

@classmethod
def _build_cache_key(
cls,
Expand Down
110 changes: 110 additions & 0 deletions integrations/openai/tests/unit_tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,116 @@ def test_init_branch_resource_path_resolves_host(
)


class TestSanitizeItems:
"""Pure walker that reconciles orphan function_call / function_call_output
items. Shared by both the destructive ``repair()`` path and the read-time
``get_items()`` filter."""

def _items_for(self, *types_and_ids):
# Helper: build items from (type, call_id) tuples.
items = []
for spec in types_and_ids:
if isinstance(spec, str):
items.append({"role": "user", "content": spec})
else:
t, cid = spec
items.append(
{"type": t, "call_id": cid, "name": "f", "arguments": "{}"}
if t == "function_call"
else {"type": t, "call_id": cid, "output": "ok"}
)
return items

def test_noop_when_clean_returns_same_list(self):
from databricks_openai.agents.session import _sanitize_items

items = self._items_for(
"hi",
("function_call", "c1"),
("function_call_output", "c1"),
"done",
)
out = _sanitize_items(items)
assert out is items # caller can skip re-persistence

def test_injects_synthetic_output_for_orphan_call(self):
from databricks_openai.agents.session import _sanitize_items

items = self._items_for("hi", ("function_call", "c1"))
out = _sanitize_items(items)
assert len(out) == 3
assert out[-1]["type"] == "function_call_output"
assert out[-1]["call_id"] == "c1"

def test_injects_for_multiple_orphan_calls(self):
# Scenario the user hit: multiple parallel tool_calls, all orphaned.
from databricks_openai.agents.session import _sanitize_items

items = self._items_for(
"hi",
("function_call", "c1"),
("function_call", "c2"),
("function_call", "c3"),
)
out = _sanitize_items(items)
calls = [i for i in out if i.get("type") == "function_call"]
outputs = [i for i in out if i.get("type") == "function_call_output"]
assert len(calls) == 3
assert len(outputs) == 3
assert {o["call_id"] for o in outputs} == {"c1", "c2", "c3"}

def test_drops_orphan_output_with_no_matching_call(self):
from databricks_openai.agents.session import _sanitize_items

items = self._items_for("hi", ("function_call_output", "ghost"))
out = _sanitize_items(items)
assert all(i.get("type") != "function_call_output" for i in out)

def test_dedupes_duplicate_calls_and_outputs(self):
from databricks_openai.agents.session import _sanitize_items

items = self._items_for(
("function_call", "c1"),
("function_call", "c1"),
("function_call_output", "c1"),
("function_call_output", "c1"),
)
out = _sanitize_items(items)
assert len(out) == 2


class TestAsyncGetItemsAutoRepair:
"""get_items() always applies read-time repair. Uses a minimal subclass
that bypasses parent SQLAlchemySession init so we can exercise the
override without a DB."""

def _fake_session(self, items):
from databricks_openai.agents.session import AsyncDatabricksSession, _sanitize_items

class _FakeSession(AsyncDatabricksSession):
def __init__(self, stored):
# Bypass parent init — only need the stored items.
self._stored = stored

async def get_items(self, limit=None):
return _sanitize_items(list(self._stored))

return _FakeSession(items)

@pytest.mark.asyncio
async def test_auto_repair_injects_synthetic_outputs(self):
sess = self._fake_session(
[
{"role": "user", "content": "hi"},
{"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"},
{"type": "function_call", "call_id": "c2", "name": "f", "arguments": "{}"},
]
)
items = await sess.get_items()
synth = [i for i in items if i.get("type") == "function_call_output"]
assert len(synth) == 2


# =============================================================================
# Schema Tests
# =============================================================================
Expand Down
Loading
Loading