diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/__init__.py b/py/src/braintrust/wrappers/claude_agent_sdk/__init__.py index 8b596860..1d44358c 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/__init__.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/__init__.py @@ -19,7 +19,7 @@ from braintrust.logger import NOOP_SPAN, current_span, init_logger -from ._wrapper import _create_client_wrapper_class, _create_tool_wrapper_class, _wrap_tool_factory +from ._wrapper import _create_client_wrapper_class, _create_tool_wrapper_class logger = logging.getLogger(__name__) @@ -69,7 +69,6 @@ def setup_claude_agent_sdk( original_client = claude_agent_sdk.ClaudeSDKClient if hasattr(claude_agent_sdk, "ClaudeSDKClient") else None original_tool_class = claude_agent_sdk.SdkMcpTool if hasattr(claude_agent_sdk, "SdkMcpTool") else None - original_tool_fn = claude_agent_sdk.tool if hasattr(claude_agent_sdk, "tool") else None if original_client: wrapped_client = _create_client_wrapper_class(original_client) @@ -89,15 +88,6 @@ def setup_claude_agent_sdk( if getattr(module, "SdkMcpTool", None) is original_tool_class: setattr(module, "SdkMcpTool", wrapped_tool_class) - if original_tool_fn: - wrapped_tool_fn = _wrap_tool_factory(original_tool_fn) - claude_agent_sdk.tool = wrapped_tool_fn - - for module in list(sys.modules.values()): - if module and hasattr(module, "tool"): - if getattr(module, "tool", None) is original_tool_fn: - setattr(module, "tool", wrapped_tool_fn) - return True except ImportError: # Not installed - this is expected when using auto_instrument() diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py index e019241d..1cefc6d4 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py @@ -1,6 +1,7 @@ import asyncio +import collections import dataclasses -import logging +import json import threading import time from 
collections.abc import AsyncGenerator, AsyncIterable @@ -25,7 +26,6 @@ ) -log = logging.getLogger(__name__) _thread_local = threading.local() @@ -43,6 +43,8 @@ class _ActiveToolSpan: raw_name: str display_name: str input: Any + tool_use_id: str | None = None + parent_tool_use_id: str | None = None handler_active: bool = False @property @@ -79,10 +81,6 @@ def release(self) -> None: _NOOP_ACTIVE_TOOL_SPAN = _NoopActiveToolSpan() -def _log_tracing_warning(exc: Exception) -> None: - log.warning("Error in tracing code", exc_info=exc) - - def _parse_tool_name(tool_name: Any) -> ParsedToolName: raw_name = str(tool_name) if tool_name is not None else DEFAULT_TOOL_NAME @@ -187,26 +185,6 @@ def __init__( return WrappedSdkMcpTool -def _wrap_tool_factory(tool_fn: Any) -> Any: - """Wrap the tool() factory so decorated handlers inherit the active TOOL span.""" - - def wrapped_tool(*args: Any, **kwargs: Any) -> Any: - result = tool_fn(*args, **kwargs) - if not callable(result): - return result - - def wrapped_decorator(handler_fn: Any) -> Any: - tool_def = result(handler_fn) - if tool_def and hasattr(tool_def, "handler"): - tool_name = getattr(tool_def, "name", DEFAULT_TOOL_NAME) - tool_def.handler = _wrap_tool_handler(tool_def.handler, tool_name) - return tool_def - - return wrapped_decorator - - return wrapped_tool - - def _wrap_tool_handler(handler: Any, tool_name: Any) -> Any: """Wrap a tool handler so nested spans execute under the stream-based TOOL span.""" if hasattr(handler, "_braintrust_wrapped"): @@ -236,15 +214,29 @@ async def wrapped_handler(args: Any) -> Any: return wrapped_handler +def _make_dispatch_key(tool_name: str, tool_input: Any) -> tuple[str, str]: + """Create a hashable key for dispatch queue lookup from tool name and input.""" + try: + input_sig = json.dumps(tool_input, sort_keys=True, default=str) + except (TypeError, ValueError): + input_sig = repr(tool_input) + return (tool_name, input_sig) + + class ToolSpanTracker: def __init__(self): 
self._active_spans: dict[str, _ActiveToolSpan] = {} - self._pending_task_link_tool_use_ids: set[str] = set() + # Per-(tool_name, input_signature) FIFO queue of tool_use_ids. + # Used by acquire_span_for_handler to disambiguate identical concurrent + # tool calls (same name + same input) from sibling subagents. + self._dispatch_queues: dict[tuple[str, str], collections.deque[str]] = {} def start_tool_spans(self, message: Any, llm_span_export: str | None) -> None: if llm_span_export is None or not hasattr(message, "content"): return + message_parent_tool_use_id = getattr(message, "parent_tool_use_id", None) + for block in message.content: if type(block).__name__ != BlockClassName.TOOL_USE: continue @@ -277,14 +269,17 @@ def start_tool_spans(self, message: Any, llm_span_export: str | None) -> None: metadata=metadata, parent=llm_span_export, ) + tool_input = getattr(block, "input", None) self._active_spans[tool_use_id] = _ActiveToolSpan( span=tool_span, raw_name=parsed_tool_name.raw_name, display_name=parsed_tool_name.display_name, - input=getattr(block, "input", None), + input=tool_input, + tool_use_id=tool_use_id, + parent_tool_use_id=message_parent_tool_use_id, ) - if parsed_tool_name.display_name == "Agent": - self._pending_task_link_tool_use_ids.add(tool_use_id) + dispatch_key = _make_dispatch_key(parsed_tool_name.raw_name, tool_input) + self._dispatch_queues.setdefault(dispatch_key, collections.deque()).append(tool_use_id) def finish_tool_spans(self, message: Any) -> None: if not hasattr(message, "content"): @@ -300,26 +295,34 @@ def finish_tool_spans(self, message: Any) -> None: self._end_tool_span(str(tool_use_id), tool_result_block=block) - def cleanup(self, end_time: float | None = None, exclude_tool_use_ids: frozenset[str] | None = None) -> None: + def cleanup_context( + self, + parent_tool_use_id: str | None, + *, + end_time: float | None = None, + exclude_ids: frozenset[str] = frozenset(), + ) -> None: + """Close tool spans belonging to one subagent 
context. + + Skips any span whose tool_use_id is in exclude_ids (live Agent spans). + Called before starting a new LLM span for that context. + """ for tool_use_id in list(self._active_spans): - if exclude_tool_use_ids and tool_use_id in exclude_tool_use_ids: + if tool_use_id in exclude_ids: continue + if self._active_spans[tool_use_id].parent_tool_use_id != parent_tool_use_id: + continue + self._end_tool_span(tool_use_id, end_time=end_time) + + def cleanup_all(self, end_time: float | None = None) -> None: + """Close all remaining active spans. Called at end-of-stream.""" + for tool_use_id in list(self._active_spans): self._end_tool_span(tool_use_id, end_time=end_time) @property def has_active_spans(self) -> bool: return bool(self._active_spans) - @property - def pending_task_link_tool_use_ids(self) -> frozenset[str]: - return frozenset(self._pending_task_link_tool_use_ids) - - def mark_task_started(self, tool_use_id: Any) -> None: - if tool_use_id is None: - return - - self._pending_task_link_tool_use_ids.discard(str(tool_use_id)) - def acquire_span_for_handler(self, tool_name: Any, args: Any) -> _ActiveToolSpan | None: parsed_tool_name = _parse_tool_name(tool_name) candidate_names = list( @@ -333,21 +336,55 @@ def acquire_span_for_handler(self, tool_name: Any, args: Any) -> _ActiveToolSpan and (active_tool_span.raw_name in candidate_names or active_tool_span.display_name in candidate_names) ] - matched_span = _match_tool_span_for_handler(candidates, args) + matched_span = self._match_via_dispatch_queue(parsed_tool_name.raw_name, args, candidates) + if matched_span is None: + matched_span = _match_tool_span_for_handler(candidates, args) if matched_span is None: return None matched_span.activate() return matched_span + def _match_via_dispatch_queue( + self, raw_name: str, args: Any, candidates: list[_ActiveToolSpan] + ) -> _ActiveToolSpan | None: + """Use the dispatch queue to match by tool_use_id when multiple identical + candidates exist (same name + same input 
from different subagents).""" + dispatch_key = _make_dispatch_key(raw_name, args) + queue = self._dispatch_queues.get(dispatch_key) + if not queue: + return None + + # Pop tool_use_ids until we find one that corresponds to an available + # (non-handler_active) candidate, skipping stale entries. + candidate_ids = {c.tool_use_id for c in candidates} + while queue: + tool_use_id = queue.popleft() + if tool_use_id in candidate_ids: + for candidate in candidates: + if candidate.tool_use_id == tool_use_id: + return candidate + + return None + def _end_tool_span( self, tool_use_id: str, tool_result_block: Any | None = None, end_time: float | None = None ) -> None: active_tool_span = self._active_spans.pop(tool_use_id, None) - self._pending_task_link_tool_use_ids.discard(tool_use_id) if active_tool_span is None: return + # Remove from dispatch queue so stale entries don't accumulate. + dispatch_key = _make_dispatch_key(active_tool_span.raw_name, active_tool_span.input) + queue = self._dispatch_queues.get(dispatch_key) + if queue: + try: + queue.remove(tool_use_id) + except ValueError: + pass + if not queue: + del self._dispatch_queues[dispatch_key] + if tool_result_block is None: active_tool_span.span.end(end_time=end_time) return @@ -396,236 +433,351 @@ def _activate_tool_span_for_handler(tool_name: Any, args: Any) -> _ActiveToolSpa return tool_span_tracker.acquire_span_for_handler(tool_name, args) or _NOOP_ACTIVE_TOOL_SPAN -class LLMSpanTracker: - """Manages LLM span lifecycle for Claude Agent SDK message streams. +def _msg_field(message: Any, field: str) -> Any: + """Read a field from a system message, falling back to message.data for older SDK versions. - Message flow per turn: - 1. UserMessage (tool results) -> mark the time when next LLM will start - 2. AssistantMessage - LLM response arrives -> create span with the marked start time, ending previous span - 3. 
ResultMessage - usage metrics -> log to span + SDK >= 0.1.11 exposes TaskStartedMessage / TaskProgressMessage / + TaskNotificationMessage with fields as top-level attributes. + SDK 0.1.10 uses a flat SystemMessage(subtype, data=) + where task fields live directly in data (e.g. data["task_id"]). + """ + value = getattr(message, field, None) + if value is not None: + return value + # Older SDK: message.data is the full raw payload dict with task fields at its top level. + data = getattr(message, "data", None) + if isinstance(data, dict): + return data.get(field) + return None + + +def _task_span_name(message: Any, task_id: str) -> str: + return _msg_field(message, "description") or _msg_field(message, "task_type") or f"Task {task_id}" + + +def _task_metadata(message: Any) -> dict[str, Any]: + return { + k: v + for k, v in { + "task_id": _msg_field(message, "task_id"), + "session_id": _msg_field(message, "session_id"), + "tool_use_id": _msg_field(message, "tool_use_id"), + "task_type": _msg_field(message, "task_type"), + "status": _msg_field(message, "status"), + "last_tool_name": _msg_field(message, "last_tool_name"), + "usage": _msg_field(message, "usage"), + }.items() + if v is not None + } + + +def _task_output(message: Any) -> dict[str, Any] | None: + summary = _msg_field(message, "summary") + output_file = _msg_field(message, "output_file") + + if summary is None and output_file is None: + return None + + return { + k: v + for k, v in { + "summary": summary, + "output_file": output_file, + }.items() + if v is not None + } - We end the previous span when the next AssistantMessage arrives, using the marked - start time to ensure sequential spans (no overlapping LLM spans). 
+ +def _message_starts_subagent_tool(message: Any) -> bool: + if not hasattr(message, "content"): + return False + + for block in message.content: + if type(block).__name__ != BlockClassName.TOOL_USE: + continue + if getattr(block, "name", None) == "Agent": + return True + + return False + + +@dataclasses.dataclass +class _AgentContext: + """Per-subagent-context state, keyed by parent_tool_use_id (None = orchestrator).""" + + llm_span: Any | None = None + llm_parent_export: str | None = None + llm_output: list[dict[str, Any]] | None = None + next_llm_start: float | None = None + task_span: Any | None = None + task_confirmed: bool = False + + +class ContextTracker: + """Single consumer of the raw SDK message stream. + + Replaces LLMSpanTracker + TaskEventSpanTracker with unified per-subagent + context tracking. Owns a private ToolSpanTracker instance. """ - def __init__(self, query_start_time: float | None = None): - self.current_span: Any | None = None - self.current_span_export: str | None = None - self.current_parent_export: str | None = None - self.current_output: list[dict[str, Any]] | None = None - self.next_start_time: float | None = query_start_time + def __init__( + self, + root_span: Any, + prompt: Any, + query_start_time: float | None = None, + captured_messages: list[dict[str, Any]] | None = None, + ) -> None: + self._root_span = root_span + self._root_span_export = root_span.export() + self._prompt = prompt + self._captured_messages = captured_messages # logged to root span on first add() + + self._tool_tracker = ToolSpanTracker() + self._contexts: dict[str | None, _AgentContext] = {None: _AgentContext(next_llm_start=query_start_time)} + self._active_key: str | None = None + self._task_order: list[str | None] = [] + + self._final_results: list[dict[str, Any]] = [] + self._task_events: list[dict[str, Any]] = [] + + _thread_local.tool_span_tracker = self._tool_tracker + + # -- public API -- + + def add(self, message: Any) -> None: + """Consume one SDK 
message and update spans accordingly.""" + if self._captured_messages is not None: + if self._captured_messages: + self._root_span.log(input=self._captured_messages) + self._captured_messages = None + + message_type = type(message).__name__ + if message_type == MessageClassName.ASSISTANT: + self._handle_assistant(message) + elif message_type == MessageClassName.USER: + self._handle_user(message) + elif message_type == MessageClassName.RESULT: + self._handle_result(message) + elif message_type in SYSTEM_MESSAGE_TYPES: + self._handle_system(message) + + def log_output(self) -> None: + """Log the last accumulated assistant message as the root span output.""" + if self._final_results: + self._root_span.log(output=self._final_results[-1]) + + def log_tasks(self) -> None: + """Flush accumulated task events to the root span metadata.""" + if self._task_events: + self._root_span.log(metadata={"task_events": self._task_events}) + + def cleanup(self) -> None: + """End all open LLM spans, TASK spans, and TOOL spans; clear thread-local.""" + for ctx in self._contexts.values(): + if ctx.llm_span: + ctx.llm_span.end() + ctx.llm_span = None + if ctx.task_span: + ctx.task_span.end() + ctx.task_span = None + self._task_order.clear() + self._tool_tracker.cleanup_all() + if hasattr(_thread_local, "tool_span_tracker"): + delattr(_thread_local, "tool_span_tracker") + + # -- internal handlers -- + + def _handle_assistant(self, message: Any) -> None: + incoming_parent = getattr(message, "parent_tool_use_id", None) + self._active_key = incoming_parent + ctx = self._get_context(incoming_parent) + + # Close dangling tool spans from the previous turn in this context. 
+ if ctx.llm_span and self._tool_tracker.has_active_spans: + self._tool_tracker.cleanup_context( + incoming_parent, + end_time=ctx.next_llm_start or time.time(), + exclude_ids=self._live_agent_tool_use_ids(), + ) + + parent_export = self._llm_parent_for_message(message) + final_content, extended = self._start_or_merge_llm_span(message, parent_export, ctx) + + llm_export = ctx.llm_span.export() if ctx.llm_span else None + self._tool_tracker.start_tool_spans(message, llm_export) - def get_next_start_time(self) -> float: - return self.next_start_time if self.next_start_time is not None else time.time() + self._register_pending_agent_contexts(message) - def start_llm_span( + if final_content: + if extended and self._final_results and self._final_results[-1].get("role") == "assistant": + self._final_results[-1] = final_content + else: + self._final_results.append(final_content) + + def _handle_user(self, message: Any) -> None: + self._tool_tracker.finish_tool_spans(message) + has_tool_results = False + if hasattr(message, "content"): + has_tool_results = any(type(b).__name__ == BlockClassName.TOOL_RESULT for b in message.content) + content = _serialize_content_blocks(message.content) + self._final_results.append({"content": content, "role": "user"}) + if has_tool_results: + user_parent = getattr(message, "parent_tool_use_id", None) + resolved_key = user_parent if user_parent is not None else self._active_key + self._get_context(resolved_key).next_llm_start = time.time() + + def _handle_result(self, message: Any) -> None: + self._active_key = None + if hasattr(message, "usage"): + usage_metrics = _extract_usage_from_result_message(message) + ctx = self._get_context(None) + if ctx.llm_span and usage_metrics: + ctx.llm_span.log(metrics=usage_metrics) + result_metadata = { + k: v + for k, v in { + "num_turns": getattr(message, "num_turns", None), + "session_id": getattr(message, "session_id", None), + }.items() + if v is not None + } + if result_metadata: + 
self._root_span.log(metadata=result_metadata) + + def _handle_system(self, message: Any) -> None: + agent_span_export = self._tool_tracker.get_span_export(_msg_field(message, "tool_use_id")) + self._process_task_event(message, agent_span_export) + self._task_events.append(_serialize_system_message(message)) + + # -- internal helpers -- + + def _get_context(self, key: str | None) -> _AgentContext: + ctx = self._contexts.get(key) + if ctx is None: + ctx = _AgentContext() + self._contexts[key] = ctx + return ctx + + def _register_pending_agent_contexts(self, message: Any) -> None: + """Pre-create _AgentContext for Agent tool calls (task_confirmed=False).""" + if not hasattr(message, "content"): + return + for block in message.content: + if type(block).__name__ == BlockClassName.TOOL_USE and getattr(block, "name", None) == "Agent": + tool_use_id = getattr(block, "id", None) + if tool_use_id: + self._get_context(str(tool_use_id)) + + def _live_agent_tool_use_ids(self) -> frozenset[str]: + """Return tool_use_ids of Agent spans that must not be closed yet.""" + result: set[str] = set() + for key, ctx in self._contexts.items(): + if key is None: + continue + if not ctx.task_confirmed or ctx.task_span is not None: + result.add(key) + return frozenset(result) + + def _llm_parent_for_message(self, message: Any) -> str: + """Determine the parent span export for an incoming AssistantMessage.""" + parent_tool_use_id = getattr(message, "parent_tool_use_id", None) + if parent_tool_use_id is not None: + ctx = self._contexts.get(str(parent_tool_use_id)) + if ctx is not None and ctx.task_span is not None: + return ctx.task_span.export() + + if _message_starts_subagent_tool(message): + return self._root_span_export + + for key in reversed(self._task_order): + ctx = self._contexts.get(key) + if ctx is not None and ctx.task_span is not None: + return ctx.task_span.export() + + return self._root_span_export + + def _start_or_merge_llm_span( self, message: Any, - prompt: Any, - 
conversation_history: list[dict[str, Any]], - parent_export: str | None = None, - start_time: float | None = None, + parent_export: str | None, + ctx: _AgentContext, ) -> tuple[dict[str, Any] | None, bool]: - """Start a new LLM span, ending the previous one if it exists.""" + """Start a new LLM span or extend the existing one via merge.""" current_message = _serialize_assistant_message(message) + # Merge path. if ( - self.current_span - and self.next_start_time is None - and self.current_parent_export == parent_export + ctx.llm_span + and ctx.next_llm_start is None + and ctx.llm_parent_export == parent_export and current_message is not None ): - merged_message = _merge_assistant_messages( - self.current_output[0] if self.current_output else None, + merged = _merge_assistant_messages( + ctx.llm_output[0] if ctx.llm_output else None, current_message, ) - if merged_message is not None: - self.current_output = [merged_message] - self.current_span.log(output=self.current_output) - return merged_message, True + if merged is not None: + ctx.llm_output = [merged] + ctx.llm_span.log(output=ctx.llm_output) + return merged, True - resolved_start_time = start_time if start_time is not None else self.get_next_start_time() + # New span path. 
+ resolved_start = ctx.next_llm_start or time.time() first_token_time = time.time() - if self.current_span: - self.current_span.end(end_time=resolved_start_time) + if ctx.llm_span: + ctx.llm_span.end(end_time=resolved_start) final_content, span = _create_llm_span_for_messages( [message], - prompt, - conversation_history, + self._prompt, + self._final_results, parent=parent_export, - start_time=resolved_start_time, + start_time=resolved_start, ) if span is not None: - span.log(metrics={"time_to_first_token": max(0.0, first_token_time - resolved_start_time)}) - self.current_span = span - self.current_span_export = span.export() if span else None - self.current_parent_export = parent_export - self.current_output = [final_content] if final_content is not None else None - self.next_start_time = None + span.log(metrics={"time_to_first_token": max(0.0, first_token_time - resolved_start)}) + ctx.llm_span = span + ctx.llm_parent_export = parent_export + ctx.llm_output = [final_content] if final_content is not None else None + ctx.next_llm_start = None return final_content, False - def mark_next_llm_start(self) -> None: - """Mark when the next LLM call will start (after tool results).""" - self.next_start_time = time.time() - - def log_usage(self, usage_metrics: dict[str, float]) -> None: - """Log usage metrics to the current LLM span.""" - if self.current_span and usage_metrics: - self.current_span.log(metrics=usage_metrics) - - def cleanup(self) -> None: - """End any unclosed spans.""" - if self.current_span: - self.current_span.end() - self.current_span = None - self.current_span_export = None - self.current_parent_export = None - self.current_output = None - - -class TaskEventSpanTracker: - def __init__(self, root_span_export: str, tool_tracker: ToolSpanTracker): - self._root_span_export = root_span_export - self._tool_tracker = tool_tracker - self._active_spans: dict[str, Any] = {} - self._task_span_by_tool_use_id: dict[str, Any] = {} - self._active_task_order: 
list[str] = [] - - def process(self, message: Any) -> None: - task_id = getattr(message, "task_id", None) + def _process_task_event(self, message: Any, agent_span_export: str | None) -> None: + """Handle TaskStarted / TaskProgress / TaskNotification system messages.""" + task_id = _msg_field(message, "task_id") if task_id is None: return - task_id = str(task_id) + tool_use_id = _msg_field(message, "tool_use_id") + tool_use_id_str = str(tool_use_id) if tool_use_id is not None else None + ctx = self._get_context(tool_use_id_str) message_type = type(message).__name__ - task_span = self._active_spans.get(task_id) - if task_span is None: - task_span = start_span( - name=self._span_name(message, task_id), + if ctx.task_span is None: + ctx.task_span = start_span( + name=_task_span_name(message, task_id), span_attributes={"type": SpanTypeAttribute.TASK}, - metadata=self._metadata(message), - parent=self._parent_export(message), + metadata=_task_metadata(message), + parent=agent_span_export or self._root_span_export, ) - self._active_spans[task_id] = task_span - self._active_task_order.append(task_id) - tool_use_id = getattr(message, "tool_use_id", None) - if tool_use_id is not None: - tool_use_id = str(tool_use_id) - self._task_span_by_tool_use_id[tool_use_id] = task_span - self._tool_tracker.mark_task_started(tool_use_id) + ctx.task_confirmed = True + self._task_order.append(tool_use_id_str) else: update: dict[str, Any] = {} - metadata = self._metadata(message) + metadata = _task_metadata(message) if metadata: update["metadata"] = metadata - - output = self._output(message) + output = _task_output(message) if output is not None: update["output"] = output - if update: - task_span.log(**update) - - if self._should_end(message_type): - tool_use_id = getattr(message, "tool_use_id", None) - if tool_use_id is not None: - self._task_span_by_tool_use_id.pop(str(tool_use_id), None) - task_span.end() - del self._active_spans[task_id] - self._active_task_order = [ - active_task_id 
for active_task_id in self._active_task_order if active_task_id != task_id - ] - - @property - def active_tool_use_ids(self) -> frozenset[str]: - return frozenset(self._task_span_by_tool_use_id.keys()) - - def cleanup(self) -> None: - for task_id, span in list(self._active_spans.items()): - span.end() - del self._active_spans[task_id] - self._task_span_by_tool_use_id.clear() - self._active_task_order.clear() - - def parent_export_for_message(self, message: Any, fallback_export: str) -> str: - parent_tool_use_id = getattr(message, "parent_tool_use_id", None) - if parent_tool_use_id is None: - if _message_starts_subagent_tool(message): - return fallback_export - active_task_export = self._latest_active_task_export() - return active_task_export or fallback_export - - task_span = self._task_span_by_tool_use_id.get(str(parent_tool_use_id)) - if task_span is not None: - return task_span.export() - - active_task_export = self._latest_active_task_export() - return active_task_export or fallback_export - - def _latest_active_task_export(self) -> str | None: - for task_id in reversed(self._active_task_order): - task_span = self._active_spans.get(task_id) - if task_span is not None: - return task_span.export() - - return None - - def _parent_export(self, message: Any) -> str: - return self._tool_tracker.get_span_export(getattr(message, "tool_use_id", None)) or self._root_span_export + ctx.task_span.log(**update) - def _span_name(self, message: Any, task_id: str) -> str: - return getattr(message, "description", None) or getattr(message, "task_type", None) or f"Task {task_id}" - - def _metadata(self, message: Any) -> dict[str, Any]: - metadata = { - k: v - for k, v in { - "task_id": getattr(message, "task_id", None), - "session_id": getattr(message, "session_id", None), - "tool_use_id": getattr(message, "tool_use_id", None), - "task_type": getattr(message, "task_type", None), - "status": getattr(message, "status", None), - "last_tool_name": getattr(message, "last_tool_name", 
None), - "usage": getattr(message, "usage", None), - }.items() - if v is not None - } - return metadata - - def _output(self, message: Any) -> dict[str, Any] | None: - summary = getattr(message, "summary", None) - output_file = getattr(message, "output_file", None) - - if summary is None and output_file is None: - return None - - return { - k: v - for k, v in { - "summary": summary, - "output_file": output_file, - }.items() - if v is not None - } - - def _should_end(self, message_type: str) -> bool: - return message_type == MessageClassName.TASK_NOTIFICATION - - -def _message_starts_subagent_tool(message: Any) -> bool: - if not hasattr(message, "content"): - return False - - for block in message.content: - if type(block).__name__ != BlockClassName.TOOL_USE: - continue - if getattr(block, "name", None) == "Agent": - return True - - return False + if message_type == MessageClassName.TASK_NOTIFICATION: + ctx.task_span.end() + ctx.task_span = None + self._task_order = [k for k in self._task_order if k != tool_use_id_str] def _create_client_wrapper_class(original_client_class: Any) -> Any: @@ -675,103 +827,24 @@ async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]: return await self.__client.query(*args, **kwargs) async def receive_response(self) -> AsyncGenerator[Any, None]: - """Wrap receive_response to add tracing. 
- - Uses start_span context manager which automatically: - - Handles exceptions and logs them as errors - - Sets the span as current so tool calls automatically nest under it - - Manages span lifecycle (start/end) - """ + """Wrap receive_response to add tracing via ContextTracker.""" generator = self.__client.receive_response() - # Determine the initial input - may be updated later if using async generator - initial_input = self.__last_prompt if self.__last_prompt else None - with start_span( name=CLAUDE_AGENT_TASK_SPAN_NAME, span_attributes={"type": SpanTypeAttribute.TASK}, - input=initial_input, + input=self.__last_prompt or None, ) as span: - # If we're capturing async messages, we'll update input after they're consumed - input_needs_update = self.__captured_messages is not None - - final_results: list[dict[str, Any]] = [] - task_events: list[dict[str, Any]] = [] - llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time) - tool_tracker = ToolSpanTracker() - task_event_span_tracker = TaskEventSpanTracker(span.export(), tool_tracker) - _thread_local.tool_span_tracker = tool_tracker + context_tracker = ContextTracker( + root_span=span, + prompt=self.__last_prompt, + query_start_time=self.__query_start_time, + captured_messages=self.__captured_messages, + ) try: async for message in generator: - # Update input from captured async messages (once, after they're consumed) - if input_needs_update: - captured_input = self.__captured_messages if self.__captured_messages else [] - if captured_input: - span.log(input=captured_input) - input_needs_update = False - - message_type = type(message).__name__ - - if message_type == MessageClassName.ASSISTANT: - if llm_tracker.current_span and tool_tracker.has_active_spans: - active_subagent_tool_use_ids = ( - task_event_span_tracker.active_tool_use_ids - | tool_tracker.pending_task_link_tool_use_ids - ) - tool_tracker.cleanup( - end_time=llm_tracker.get_next_start_time(), - 
exclude_tool_use_ids=active_subagent_tool_use_ids, - ) - llm_parent_export = task_event_span_tracker.parent_export_for_message( - message, - span.export(), - ) - final_content, extended_existing_span = llm_tracker.start_llm_span( - message, - self.__last_prompt, - final_results, - parent_export=llm_parent_export, - ) - tool_tracker.start_tool_spans(message, llm_tracker.current_span_export) - if final_content: - if ( - extended_existing_span - and final_results - and final_results[-1].get("role") == "assistant" - ): - final_results[-1] = final_content - else: - final_results.append(final_content) - elif message_type == MessageClassName.USER: - tool_tracker.finish_tool_spans(message) - has_tool_results = False - if hasattr(message, "content"): - has_tool_results = any( - type(block).__name__ == BlockClassName.TOOL_RESULT for block in message.content - ) - content = _serialize_content_blocks(message.content) - final_results.append({"content": content, "role": "user"}) - if has_tool_results: - llm_tracker.mark_next_llm_start() - elif message_type == MessageClassName.RESULT: - if hasattr(message, "usage"): - usage_metrics = _extract_usage_from_result_message(message) - llm_tracker.log_usage(usage_metrics) - - result_metadata = { - k: v - for k, v in { - "num_turns": getattr(message, "num_turns", None), - "session_id": getattr(message, "session_id", None), - }.items() - if v is not None - } - span.log(metadata=result_metadata) - elif message_type in SYSTEM_MESSAGE_TYPES: - task_event_span_tracker.process(message) - task_events.append(_serialize_system_message(message)) - + context_tracker.add(message) yield message except asyncio.CancelledError: # The CancelledError may come from the subprocess transport @@ -780,19 +853,12 @@ async def receive_response(self) -> AsyncGenerator[Any, None]: # the response stream ends cleanly. If the caller genuinely # cancelled the task, they still have pending cancellation # requests that will fire at their next await point. 
- if final_results: - span.log(output=final_results[-1]) + context_tracker.log_output() else: - if final_results: - span.log(output=final_results[-1]) + context_tracker.log_output() finally: - if task_events: - span.log(metadata={"task_events": task_events}) - task_event_span_tracker.cleanup() - tool_tracker.cleanup() - llm_tracker.cleanup() - if hasattr(_thread_local, "tool_span_tracker"): - delattr(_thread_local, "tool_span_tracker") + context_tracker.log_tasks() + context_tracker.cleanup() async def __aenter__(self) -> "WrappedClaudeSDKClient": await self.__client.__aenter__() @@ -817,9 +883,7 @@ def _create_llm_span_for_messages( - final_content: The final message content to add to conversation history - span: The LLM span object (for logging metrics later) - Automatically nests under the current span (TASK span from receive_response). - - Note: This is called from within a catch_exceptions block, so errors won't break user code. + Called by ContextTracker._start_or_merge_llm_span with an explicit parent export. 
""" if not messages: return None, None diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting.json b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting.json new file mode 100644 index 00000000..c13cb624 --- /dev/null +++ b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting.json @@ -0,0 +1,550 @@ +{ + "cassette_name": "test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting", + "events": [ + { + "op": "write", + "payload": { + "kind": "json", + "value": { + "request": { + "hooks": null, + "subtype": "initialize" + }, + "request_id": "req_1_test_concurrent", + "type": "control_request" + } + } + }, + { + "op": "read", + "payload": { + "response": { + "request_id": "req_1_test_concurrent", + "response": { + "account": { + "apiKeySource": "ANTHROPIC_API_KEY", + "tokenSource": "none" + }, + "agents": [], + "available_output_styles": [], + "commands": [], + "models": [] + }, + "subtype": "success" + }, + "type": "control_response" + } + }, + { + "op": "write", + "payload": { + "kind": "json", + "value": { + "message": { + "content": "Run three tasks.", + "role": "user" + }, + "parent_tool_use_id": null, + "session_id": "default", + "type": "user" + } + } + }, + { + "op": "read", + "payload": { + "agents": [ + "general-purpose" + ], + "apiKeySource": "ANTHROPIC_API_KEY", + "claude_code_version": "2.1.71", + "cwd": "", + "fast_mode_state": "off", + "mcp_servers": [], + "model": "claude-haiku-4-5-20251001", + "output_style": "default", + "permissionMode": "bypassPermissions", + "plugins": [], + "session_id": "session-concurrent", + "skill_sets": [], + "subtype": "init", + "type": "system" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_agent_a", + "input": { + 
"description": "Task A", + "subagent_type": "general-purpose" + }, + "name": "Agent", + "type": "tool_use" + }, + { + "id": "toolu_agent_b", + "input": { + "description": "Task B", + "subagent_type": "general-purpose" + }, + "name": "Agent", + "type": "tool_use" + }, + { + "id": "toolu_agent_c", + "input": { + "description": "Task C", + "subagent_type": "general-purpose" + }, + "name": "Agent", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": null, + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "description": "Task A", + "session_id": "session-concurrent", + "subtype": "task_started", + "task_id": "task_a", + "task_type": "local_agent", + "tool_use_id": "toolu_agent_a", + "type": "system", + "uuid": "uuid-A-start" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "Prompt for A.", + "type": "text" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_a", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "description": "Task B", + "session_id": "session-concurrent", + "subtype": "task_started", + "task_id": "task_b", + "task_type": "local_agent", + "tool_use_id": "toolu_agent_b", + "type": "system", + "uuid": "uuid-B-start" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "Prompt for B.", + "type": "text" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_b", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "description": "Task C", + "session_id": "session-concurrent", + "subtype": "task_started", + "task_id": "task_c", + "task_type": "local_agent", + "tool_use_id": "toolu_agent_c", + "type": "system", + "uuid": "uuid-C-start" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "Prompt for C.", + "type": "text" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_c", + "type": "user" + } + }, + { + "op": "read", + 
"payload": { + "message": { + "content": [ + { + "id": "toolu_tool_a1", + "input": { + "command": "echo a1" + }, + "name": "Bash", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_a", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_tool_b1", + "input": { + "command": "echo b1" + }, + "name": "Bash", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_b", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_tool_c1", + "input": { + "q": "c1" + }, + "name": "mcp__server__remote_tool", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_c", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "a1-output", + "tool_use_id": "toolu_tool_a1", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_a", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "b1-output", + "tool_use_id": "toolu_tool_b1", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_b", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "c1-output", + "tool_use_id": "toolu_tool_c1", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_c", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_tool_a2", + "input": { + "file_path": "/tmp/a.txt" + }, + "name": "Read", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_a", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + 
"content": [ + { + "id": "toolu_tool_b2", + "input": { + "file_path": "/tmp/b.txt" + }, + "name": "Read", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_b", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_tool_c2", + "input": { + "file_path": "/tmp/c.txt" + }, + "name": "Read", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_c", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "a2-output", + "tool_use_id": "toolu_tool_a2", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_a", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "b2-output", + "tool_use_id": "toolu_tool_b2", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_b", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "c2-output", + "tool_use_id": "toolu_tool_c2", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_c", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "output_file": "", + "session_id": "session-concurrent", + "status": "completed", + "subtype": "task_notification", + "summary": "A done", + "task_id": "task_a", + "tool_use_id": "toolu_agent_a", + "type": "system", + "usage": { + "duration_ms": 500, + "tool_uses": 2, + "total_tokens": 100 + }, + "uuid": "uuid-A-done" + } + }, + { + "op": "read", + "payload": { + "output_file": "", + "session_id": "session-concurrent", + "status": "completed", + "subtype": "task_notification", + "summary": "B done", + "task_id": "task_b", + "tool_use_id": "toolu_agent_b", + "type": "system", + "usage": { + "duration_ms": 500, + "tool_uses": 2, + "total_tokens": 
100 + }, + "uuid": "uuid-B-done" + } + }, + { + "op": "read", + "payload": { + "output_file": "", + "session_id": "session-concurrent", + "status": "completed", + "subtype": "task_notification", + "summary": "C done", + "task_id": "task_c", + "tool_use_id": "toolu_agent_c", + "type": "system", + "usage": { + "duration_ms": 500, + "tool_uses": 2, + "total_tokens": 100 + }, + "uuid": "uuid-C-done" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "A complete", + "tool_use_id": "toolu_agent_a", + "type": "tool_result" + }, + { + "content": "B complete", + "tool_use_id": "toolu_agent_b", + "type": "tool_result" + }, + { + "content": "C complete", + "tool_use_id": "toolu_agent_c", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": null, + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "Done.", + "type": "text" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": null, + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "duration_api_ms": 3000, + "duration_ms": 5000, + "fast_mode_state": "off", + "is_error": false, + "num_turns": 3, + "permission_denials": [], + "result": "Done.", + "session_id": "session-concurrent", + "stop_reason": "end_turn", + "subtype": "success", + "total_cost_usd": 0.001, + "type": "result", + "usage": { + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "input_tokens": 200, + "output_tokens": 50, + "service_tier": "standard", + "speed": "standard" + }, + "uuid": "uuid-result" + } + } + ], + "sdk_version": "0.1.48" +} diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting_sdk_0_1_10.json b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting_sdk_0_1_10.json new file mode 100644 index 00000000..0ed4f710 
--- /dev/null +++ b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting_sdk_0_1_10.json @@ -0,0 +1,186 @@ +{ + "cassette_name": "test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting_sdk_0_1_10", + "events": [ + { + "op": "write", + "payload": { + "kind": "json", + "value": { + "request": { + "hooks": null, + "subtype": "initialize" + }, + "request_id": "req_1_9d03d2d5", + "type": "control_request" + } + } + }, + { + "op": "read", + "payload": { + "response": { + "request_id": "req_1_9d03d2d5", + "response": { + "account": { + "apiKeySource": "ANTHROPIC_API_KEY", + "tokenSource": "none" + }, + "available_output_styles": [], + "commands": [], + "models": [] + }, + "subtype": "success" + }, + "type": "control_response" + } + }, + { + "op": "write", + "payload": { + "kind": "json", + "value": { + "message": { + "content": "Run three tasks.", + "role": "user" + }, + "parent_tool_use_id": null, + "session_id": "default", + "type": "user" + } + } + }, + { + "op": "read", + "payload": { + "agents": [ + "general-purpose", + "statusline-setup", + "Explore", + "Plan" + ], + "apiKeySource": "ANTHROPIC_API_KEY", + "claude_code_version": "2.0.53", + "cwd": "", + "mcp_servers": [], + "model": "claude-haiku-4-5-20251001", + "output_style": "default", + "permissionMode": "bypassPermissions", + "plugins": [], + "session_id": "b604680d-6581-44d7-a3af-a2ed91069472", + "skills": [], + "slash_commands": [ + "compact", + "context", + "cost", + "init", + "pr-comments", + "release-notes", + "todos", + "review", + "security-review" + ], + "subtype": "init", + "tools": [ + "Task", + "Bash", + "Glob", + "Grep", + "ExitPlanMode", + "Read", + "Edit", + "Write", + "NotebookEdit", + "WebFetch", + "TodoWrite", + "WebSearch", + "BashOutput", + "KillShell", + "Skill", + "SlashCommand", + "EnterPlanMode" + ], + "type": "system", + "uuid": "c865727f-7d34-4507-b61f-62a2783275a9" + } + }, + { + "op": 
"read", + "payload": { + "message": { + "content": [ + { + "text": "I'd be happy to help you run three tasks! However, I need more information about what tasks you'd like me to perform. Could you please specify:\n\n1. **Task 1**: What would you like me to do?\n2. **Task 2**: What would you like me to do?\n3. **Task 3**: What would you like me to do?\n\nFor example, you could ask me to:\n- Search through code files\n- Read or edit specific files\n- Run bash commands\n- Create or modify code\n- Analyze documentation\n- Or anything else you need help with\n\nPlease provide details about each task you'd like me to complete.", + "type": "text" + } + ], + "context_management": null, + "id": "msg_01V8F4DzGofweXoRuffhD7US", + "model": "claude-haiku-4-5-20251001", + "role": "assistant", + "stop_reason": null, + "stop_sequence": null, + "type": "message", + "usage": { + "cache_creation": { + "ephemeral_1h_input_tokens": 0, + "ephemeral_5m_input_tokens": 0 + }, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 13878, + "inference_geo": "not_available", + "input_tokens": 3, + "output_tokens": 2, + "service_tier": "standard" + } + }, + "parent_tool_use_id": null, + "session_id": "b604680d-6581-44d7-a3af-a2ed91069472", + "type": "assistant", + "uuid": "9721c458-3826-477e-9d4c-f999c579eb19" + } + }, + { + "op": "read", + "payload": { + "duration_api_ms": 4085, + "duration_ms": 2145, + "is_error": false, + "modelUsage": { + "claude-haiku-4-5-20251001": { + "cacheCreationInputTokens": 0, + "cacheReadInputTokens": 13878, + "contextWindow": 200000, + "costUSD": 0.0039508, + "inputTokens": 883, + "outputTokens": 336, + "webSearchRequests": 0 + } + }, + "num_turns": 1, + "permission_denials": [], + "result": "I'd be happy to help you run three tasks! However, I need more information about what tasks you'd like me to perform. Could you please specify:\n\n1. **Task 1**: What would you like me to do?\n2. **Task 2**: What would you like me to do?\n3. 
**Task 3**: What would you like me to do?\n\nFor example, you could ask me to:\n- Search through code files\n- Read or edit specific files\n- Run bash commands\n- Create or modify code\n- Analyze documentation\n- Or anything else you need help with\n\nPlease provide details about each task you'd like me to complete.", + "session_id": "b604680d-6581-44d7-a3af-a2ed91069472", + "subtype": "success", + "total_cost_usd": 0.0039508, + "type": "result", + "usage": { + "cache_creation": { + "ephemeral_1h_input_tokens": 0, + "ephemeral_5m_input_tokens": 0 + }, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 13878, + "input_tokens": 3, + "output_tokens": 145, + "server_tool_use": { + "web_fetch_requests": 0, + "web_search_requests": 0 + }, + "service_tier": "standard" + }, + "uuid": "a318b707-c0e0-456a-8831-9e17587b89d8" + } + } + ], + "sdk_version": "0.1.10" +} diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_interleaved_subagent_tool_output_preserved.json b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_interleaved_subagent_tool_output_preserved.json new file mode 100644 index 00000000..1317e35b --- /dev/null +++ b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_interleaved_subagent_tool_output_preserved.json @@ -0,0 +1,340 @@ +{ + "cassette_name": "test_interleaved_subagent_tool_output_preserved", + "events": [ + { + "op": "write", + "payload": { + "kind": "json", + "value": { + "request": { + "hooks": null, + "subtype": "initialize" + }, + "request_id": "req_1_test_interleave", + "type": "control_request" + } + } + }, + { + "op": "read", + "payload": { + "response": { + "request_id": "req_1_test_interleave", + "response": { + "account": { + "apiKeySource": "ANTHROPIC_API_KEY", + "tokenSource": "none" + }, + "agents": [], + "available_output_styles": [], + "commands": [], + "models": [] + }, + "subtype": "success" + }, + "type": "control_response" + } + }, + { + "op": "write", + "payload": { + "kind": "json", 
+ "value": { + "message": { + "content": "Launch two subagents to process files.", + "role": "user" + }, + "parent_tool_use_id": null, + "session_id": "default", + "type": "user" + } + } + }, + { + "op": "read", + "payload": { + "agents": [ + "general-purpose" + ], + "apiKeySource": "ANTHROPIC_API_KEY", + "claude_code_version": "2.1.71", + "cwd": "", + "fast_mode_state": "off", + "mcp_servers": [], + "model": "claude-haiku-4-5-20251001", + "output_style": "default", + "permissionMode": "bypassPermissions", + "plugins": [], + "session_id": "session-interleave-test", + "skill_sets": [], + "subtype": "init", + "type": "system" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_agent_alpha", + "input": { + "description": "Process alpha file", + "subagent_type": "general-purpose" + }, + "name": "Agent", + "type": "tool_use" + }, + { + "id": "toolu_agent_beta", + "input": { + "description": "Process beta file", + "subagent_type": "general-purpose" + }, + "name": "Agent", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": null, + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "description": "Process alpha file", + "session_id": "session-interleave-test", + "subtype": "task_started", + "task_id": "task_alpha_001", + "task_type": "local_agent", + "tool_use_id": "toolu_agent_alpha", + "type": "system", + "uuid": "uuid-alpha-start" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "Process the alpha file.", + "type": "text" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_alpha", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "description": "Process beta file", + "session_id": "session-interleave-test", + "subtype": "task_started", + "task_id": "task_beta_001", + "task_type": "local_agent", + "tool_use_id": "toolu_agent_beta", + "type": "system", + "uuid": "uuid-beta-start" + } + }, + { + "op": 
"read", + "payload": { + "message": { + "content": [ + { + "text": "Process the beta file.", + "type": "text" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_beta", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_bash_alpha", + "input": { + "command": "cat /tmp/alpha.txt" + }, + "name": "Bash", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_alpha", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "id": "toolu_read_beta", + "input": { + "file_path": "/tmp/beta.txt" + }, + "name": "Read", + "type": "tool_use" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": "toolu_agent_beta", + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "alpha_file_contents", + "tool_use_id": "toolu_bash_alpha", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_alpha", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "beta_file_contents", + "tool_use_id": "toolu_read_beta", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": "toolu_agent_beta", + "type": "user" + } + }, + { + "op": "read", + "payload": { + "output_file": "", + "session_id": "session-interleave-test", + "status": "completed", + "subtype": "task_notification", + "summary": "Alpha processed", + "task_id": "task_alpha_001", + "tool_use_id": "toolu_agent_alpha", + "type": "system", + "usage": { + "duration_ms": 500, + "tool_uses": 1, + "total_tokens": 100 + }, + "uuid": "uuid-alpha-done" + } + }, + { + "op": "read", + "payload": { + "output_file": "", + "session_id": "session-interleave-test", + "status": "completed", + "subtype": "task_notification", + "summary": "Beta processed", + "task_id": "task_beta_001", + 
"tool_use_id": "toolu_agent_beta", + "type": "system", + "usage": { + "duration_ms": 600, + "tool_uses": 1, + "total_tokens": 120 + }, + "uuid": "uuid-beta-done" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "content": "alpha processed", + "tool_use_id": "toolu_agent_alpha", + "type": "tool_result" + }, + { + "content": "beta processed", + "tool_use_id": "toolu_agent_beta", + "type": "tool_result" + } + ], + "role": "user" + }, + "parent_tool_use_id": null, + "type": "user" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "Both files have been processed.", + "type": "text" + } + ], + "model": "claude-haiku-4-5-20251001" + }, + "parent_tool_use_id": null, + "type": "assistant" + } + }, + { + "op": "read", + "payload": { + "duration_api_ms": 3000, + "duration_ms": 5000, + "fast_mode_state": "off", + "is_error": false, + "num_turns": 3, + "permission_denials": [], + "result": "Both files have been processed.", + "session_id": "session-interleave-test", + "stop_reason": "end_turn", + "subtype": "success", + "total_cost_usd": 0.001, + "type": "result", + "usage": { + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "input_tokens": 200, + "output_tokens": 50, + "service_tier": "standard", + "speed": "standard" + }, + "uuid": "uuid-result" + } + } + ], + "sdk_version": "0.1.48" +} diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_interleaved_subagent_tool_output_preserved_sdk_0_1_10.json b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_interleaved_subagent_tool_output_preserved_sdk_0_1_10.json new file mode 100644 index 00000000..dacee961 --- /dev/null +++ b/py/src/braintrust/wrappers/claude_agent_sdk/cassettes/test_interleaved_subagent_tool_output_preserved_sdk_0_1_10.json @@ -0,0 +1,186 @@ +{ + "cassette_name": "test_interleaved_subagent_tool_output_preserved_sdk_0_1_10", + "events": [ + { + "op": "write", + "payload": { + "kind": "json", + 
"value": { + "request": { + "hooks": null, + "subtype": "initialize" + }, + "request_id": "req_1_5588877a", + "type": "control_request" + } + } + }, + { + "op": "read", + "payload": { + "response": { + "request_id": "req_1_5588877a", + "response": { + "account": { + "apiKeySource": "ANTHROPIC_API_KEY", + "tokenSource": "none" + }, + "available_output_styles": [], + "commands": [], + "models": [] + }, + "subtype": "success" + }, + "type": "control_response" + } + }, + { + "op": "write", + "payload": { + "kind": "json", + "value": { + "message": { + "content": "Launch two subagents to process files.", + "role": "user" + }, + "parent_tool_use_id": null, + "session_id": "default", + "type": "user" + } + } + }, + { + "op": "read", + "payload": { + "agents": [ + "general-purpose", + "statusline-setup", + "Explore", + "Plan" + ], + "apiKeySource": "ANTHROPIC_API_KEY", + "claude_code_version": "2.0.53", + "cwd": "", + "mcp_servers": [], + "model": "claude-haiku-4-5-20251001", + "output_style": "default", + "permissionMode": "bypassPermissions", + "plugins": [], + "session_id": "7828a1aa-bb16-40b9-bc86-51e2d517ff64", + "skills": [], + "slash_commands": [ + "compact", + "context", + "cost", + "init", + "pr-comments", + "release-notes", + "todos", + "review", + "security-review" + ], + "subtype": "init", + "tools": [ + "Task", + "Bash", + "Glob", + "Grep", + "ExitPlanMode", + "Read", + "Edit", + "Write", + "NotebookEdit", + "WebFetch", + "TodoWrite", + "WebSearch", + "BashOutput", + "KillShell", + "Skill", + "SlashCommand", + "EnterPlanMode" + ], + "type": "system", + "uuid": "8bbd422a-4cf7-4325-b7eb-5d9b3f1fbeef" + } + }, + { + "op": "read", + "payload": { + "message": { + "content": [ + { + "text": "I'd be happy to help you launch two subagents to process files! However, I need more information about what you'd like them to do.\n\nCould you please clarify:\n\n1. 
**What type of processing do you need?**\n - Exploring/searching the codebase?\n - Analyzing code for specific patterns?\n - Reading and summarizing file contents?\n - Something else?\n\n2. **What files or directories should they work with?**\n - Specific file paths or patterns?\n - Which directories to focus on?\n\n3. **What should the output be?**\n - A summary of findings?\n - Specific information extracted?\n - Code changes suggested?\n\n4. **Which agent types would be most appropriate?**\n - `general-purpose` - for complex multi-step tasks\n - `Explore` - for quickly finding files and understanding code patterns\n - `Plan` - for exploring and planning implementation\n\nOnce you provide these details, I can launch two subagents in parallel to handle your file processing tasks efficiently!", + "type": "text" + } + ], + "context_management": null, + "id": "msg_01EmuZbDHBmcwyS4nHu8ABTu", + "model": "claude-haiku-4-5-20251001", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": null, + "type": "message", + "usage": { + "cache_creation": { + "ephemeral_1h_input_tokens": 0, + "ephemeral_5m_input_tokens": 329 + }, + "cache_creation_input_tokens": 329, + "cache_read_input_tokens": 13554, + "inference_geo": "not_available", + "input_tokens": 3, + "output_tokens": 239, + "service_tier": "standard" + } + }, + "parent_tool_use_id": null, + "session_id": "7828a1aa-bb16-40b9-bc86-51e2d517ff64", + "type": "assistant", + "uuid": "f882767c-9f15-4a10-9d34-acefa6219dd5" + } + }, + { + "op": "read", + "payload": { + "duration_api_ms": 6524, + "duration_ms": 4413, + "is_error": false, + "modelUsage": { + "claude-haiku-4-5-20251001": { + "cacheCreationInputTokens": 329, + "cacheReadInputTokens": 13554, + "contextWindow": 200000, + "costUSD": 0.00466965, + "inputTokens": 938, + "outputTokens": 393, + "webSearchRequests": 0 + } + }, + "num_turns": 1, + "permission_denials": [], + "result": "I'd be happy to help you launch two subagents to process files! 
However, I need more information about what you'd like them to do.\n\nCould you please clarify:\n\n1. **What type of processing do you need?**\n - Exploring/searching the codebase?\n - Analyzing code for specific patterns?\n - Reading and summarizing file contents?\n - Something else?\n\n2. **What files or directories should they work with?**\n - Specific file paths or patterns?\n - Which directories to focus on?\n\n3. **What should the output be?**\n - A summary of findings?\n - Specific information extracted?\n - Code changes suggested?\n\n4. **Which agent types would be most appropriate?**\n - `general-purpose` - for complex multi-step tasks\n - `Explore` - for quickly finding files and understanding code patterns\n - `Plan` - for exploring and planning implementation\n\nOnce you provide these details, I can launch two subagents in parallel to handle your file processing tasks efficiently!", + "session_id": "7828a1aa-bb16-40b9-bc86-51e2d517ff64", + "subtype": "success", + "total_cost_usd": 0.00466965, + "type": "result", + "usage": { + "cache_creation": { + "ephemeral_1h_input_tokens": 0, + "ephemeral_5m_input_tokens": 329 + }, + "cache_creation_input_tokens": 329, + "cache_read_input_tokens": 13554, + "input_tokens": 3, + "output_tokens": 239, + "server_tool_use": { + "web_fetch_requests": 0, + "web_search_requests": 0 + }, + "service_tier": "standard" + }, + "uuid": "2bb0e757-b81f-4aa6-a904-13ae1e3bd1a4" + } + } + ], + "sdk_version": "0.1.10" +} diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py index eb12fa3d..44cb8426 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py @@ -61,7 +61,6 @@ def memory_logger(): def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = False): original_client = claude_agent_sdk.ClaudeSDKClient original_tool_class = claude_agent_sdk.SdkMcpTool 
- original_tool_fn = claude_agent_sdk.tool if wrap_client: claude_agent_sdk.ClaudeSDKClient = _create_client_wrapper_class(original_client) @@ -73,7 +72,6 @@ def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = Fa finally: claude_agent_sdk.ClaudeSDKClient = original_client claude_agent_sdk.SdkMcpTool = original_tool_class - claude_agent_sdk.tool = original_tool_fn @pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") @@ -222,6 +220,14 @@ def _assert_llm_spans_have_time_to_first_token(llm_spans: list[dict[str, Any]]) assert llm_span["metrics"]["time_to_first_token"] >= 0 +def _sdk_cassette_name(base: str, *, min_version: str) -> str: + """Return base cassette name for SDK >= min_version, else a version-specific variant.""" + if _sdk_version_at_least(min_version): + return base + sdk_ver = getattr(claude_agent_sdk, "__version__", "0").replace(".", "_") + return f"{base}_sdk_{sdk_ver}" + + def _sdk_version_at_least(version: str) -> bool: if not CLAUDE_SDK_AVAILABLE: return False @@ -885,7 +891,7 @@ async def test_relay_user_messages_between_parallel_agent_calls_do_not_split_llm async def test_agent_tool_spans_encapsulate_child_task_spans(memory_logger): """Agent TOOL spans must end after their child TASK spans, not before. - The mid-stream tool_tracker.cleanup() in the AssistantMessage handler must + The mid-stream tool_tracker.cleanup_context() in the AssistantMessage handler must not close Agent TOOL spans that still have active child TASK spans. Those Agent TOOL spans should only close when their ToolResult arrives. 
""" @@ -1429,7 +1435,7 @@ def test_tool_span_tracker_cleanup_closes_unmatched_spans(memory_logger): AssistantMessage(content=[ToolUseBlock(id="call-dangling", name="weather", input={"city": "Toronto"})]), llm_span.export(), ) - tracker.cleanup() + tracker.cleanup_all() llm_span.end() spans = memory_logger.pop() @@ -1711,7 +1717,7 @@ async def calculator_handler(args): ) finally: _clear_tool_span_tracker() - tracker.cleanup() + tracker.cleanup_all() llm_span.end() assert result == {"content": [{"type": "text", "text": "42"}]} @@ -1772,7 +1778,7 @@ async def calculator_handler(args): ) finally: _clear_tool_span_tracker() - tracker.cleanup() + tracker.cleanup_all() llm_span.end() spans = memory_logger.pop() @@ -1810,14 +1816,12 @@ async def test_setup_claude_agent_sdk_repro_import_before_setup(memory_logger, m assert not memory_logger.pop() original_client = claude_agent_sdk.ClaudeSDKClient original_tool_class = claude_agent_sdk.SdkMcpTool - original_tool_fn = claude_agent_sdk.tool consumer_module_name = "test_issue7_repro_module" consumer_module = types.ModuleType(consumer_module_name) consumer_module.ClaudeSDKClient = original_client consumer_module.ClaudeAgentOptions = claude_agent_sdk.ClaudeAgentOptions consumer_module.SdkMcpTool = original_tool_class - consumer_module.tool = original_tool_fn monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module) loop_errors = [] @@ -1827,9 +1831,7 @@ async def test_setup_claude_agent_sdk_repro_import_before_setup(memory_logger, m assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY) assert getattr(consumer_module, "ClaudeSDKClient") is not original_client assert getattr(consumer_module, "SdkMcpTool") is not original_tool_class - assert getattr(consumer_module, "tool") is not original_tool_fn assert claude_agent_sdk.SdkMcpTool is not original_tool_class - assert claude_agent_sdk.tool is not original_tool_fn async def main() -> None: loop = asyncio.get_running_loop() @@ -1862,3 +1864,553 
@@ async def main() -> None: assert len(task_spans) == 1 assert task_spans[0]["span_attributes"]["name"] == "Claude Agent" assert task_spans[0]["input"] == "Say hi" + + +@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") +@pytest.mark.asyncio +async def test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting(memory_logger): + """Concurrent subagent LLM spans must run in parallel, not be serialized into a single + sequential chain — and every tool span must be parented to its own subagent's LLM span + with output preserved. + + Three subagents each perform two interleaved tool rounds: + LLM(A:Bash) → LLM(B:Bash) → LLM(C:MCP tool) → result(A) → result(B) → result(C) + LLM(A:Read) → LLM(B:Read) → LLM(C:Read) → result(A) → result(B) → result(C) + + Verifies: + - Each subagent gets its own LLM spans (not shared with other subagents) + - LLM spans from different subagents overlap in time (parallel execution) + - Tool spans are parented to the correct subagent's LLM span + - Tool output is preserved despite cross-subagent message interleaving + """ + assert not memory_logger.pop() + + subagents = [ + {"label": "A", "agent_id": "toolu_agent_a", "task_id": "task_a"}, + {"label": "B", "agent_id": "toolu_agent_b", "task_id": "task_b"}, + {"label": "C", "agent_id": "toolu_agent_c", "task_id": "task_c"}, + ] + round1_tools = [ + {"id": "toolu_tool_a1", "name": "Bash", "agent_id": "toolu_agent_a", "result": "a1-output"}, + {"id": "toolu_tool_b1", "name": "Bash", "agent_id": "toolu_agent_b", "result": "b1-output"}, + { + "id": "toolu_tool_c1", + "name": "mcp__server__remote_tool", + "agent_id": "toolu_agent_c", + "result": "c1-output", + }, + ] + round2_tools = [ + {"id": "toolu_tool_a2", "name": "Read", "agent_id": "toolu_agent_a", "result": "a2-output"}, + {"id": "toolu_tool_b2", "name": "Read", "agent_id": "toolu_agent_b", "result": "b2-output"}, + {"id": "toolu_tool_c2", "name": "Read", "agent_id": "toolu_agent_c", 
"result": "c2-output"}, + ] + all_tools = round1_tools + round2_tools + + with _patched_claude_sdk(wrap_client=True): + options = claude_agent_sdk.ClaudeAgentOptions( + model=TEST_MODEL, + permission_mode="bypassPermissions", + ) + transport = make_cassette_transport( + cassette_name=_sdk_cassette_name( + "test_concurrent_subagents_produce_parallel_llm_spans_with_correct_parenting", + min_version="0.1.11", + ), + prompt="", + options=options, + ) + + async with claude_agent_sdk.ClaudeSDKClient(options=options, transport=transport) as client: + await client.query("Run three tasks.") + async for message in client.receive_response(): + if type(message).__name__ == "ResultMessage": + break + + spans = memory_logger.pop() + task_spans = _find_spans_by_type(spans, SpanTypeAttribute.TASK) + llm_spans = _find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + all_tools = round1_tools + round2_tools + + # --- 1. Root TASK span exists --- + _find_span_by_name(task_spans, "Claude Agent") + + if not _sdk_version_at_least("0.1.11"): + # SDK 0.1.10 replays a limited cassette (single assistant + result); + # only assert the root task span was produced. + return + + # --- 2. All subagent TASK spans exist --- + subagent_task_by_label: dict[str, dict[str, Any]] = {} + for sa in subagents: + subagent_task_by_label[sa["label"]] = _find_span_by_name(task_spans, f"Task {sa['label']}") + + task_id_by_span = {t["span_id"]: label for label, t in subagent_task_by_label.items()} + + # --- 3. Every tool span has output --- + non_agent_tools = [s for s in tool_spans if s["span_attributes"]["name"] != "Agent"] + tools_without_output = [s for s in non_agent_tools if s.get("output") is None] + assert not tools_without_output, ( + f"{len(tools_without_output)} of {len(non_agent_tools)} tool spans lost their output. 
" + f"Missing: {[s['span_attributes']['name'] + '(' + s.get('metadata', {}).get('gen_ai.tool.call.id', '?') + ')' for s in tools_without_output]}" + ) + + # --- 4. Tool spans are parented to the correct subagent's LLM span --- + agent_id_to_label = {sa["agent_id"]: sa["label"] for sa in subagents} + tool_id_to_label = {t["id"]: agent_id_to_label[t["agent_id"]] for t in all_tools} + + for tool in non_agent_tools: + tool_call_id = tool.get("metadata", {}).get("gen_ai.tool.call.id", "") + expected_label = tool_id_to_label.get(tool_call_id) + if expected_label is None: + continue + + parent_llm = next((s for s in llm_spans if s["span_id"] == tool["span_parents"][0]), None) + assert parent_llm is not None, f"Tool {tool_call_id} has no parent LLM span" + + llm_task_parent_id = parent_llm["span_parents"][0] + actual_label = task_id_by_span.get(llm_task_parent_id) + assert actual_label == expected_label, ( + f"Tool {tool_call_id} should be under subagent {expected_label}, got {actual_label}" + ) + + # --- 5. Correct tool output content --- + for t in all_tools: + span = next(s for s in tool_spans if s.get("metadata", {}).get("gen_ai.tool.call.id") == t["id"]) + assert span["output"]["content"] == t["result"] + + # MCP tool name should be parsed + mcp_span = next(s for s in tool_spans if s.get("metadata", {}).get("gen_ai.tool.call.id") == "toolu_tool_c1") + assert mcp_span["span_attributes"]["name"] == "remote_tool" + assert mcp_span["metadata"].get("mcp.server") == "server" + + # --- 6. Scale check --- + assert len(non_agent_tools) == 6 + assert len(llm_spans) >= 7 + assert len(task_spans) == 4 + + # --- 7. 
LLM spans from different subagents overlap (not serialized) --- + subagent_llm_spans: dict[str, list[dict[str, Any]]] = {sa["label"]: [] for sa in subagents} + for llm_span in llm_spans: + label = task_id_by_span.get(llm_span["span_parents"][0]) + if label: + subagent_llm_spans[label].append(llm_span) + + for label, llms in subagent_llm_spans.items(): + assert len(llms) == 2, f"Expected 2 LLM spans for subagent {label} (one per tool round), got {len(llms)}" + + a_first = min(subagent_llm_spans["A"], key=lambda s: s["metrics"]["start"]) + b_first = min(subagent_llm_spans["B"], key=lambda s: s["metrics"]["start"]) + assert a_first["metrics"]["end"] > b_first["metrics"]["start"], ( + f"Subagent A's first LLM span should overlap with B's (not be truncated). " + f"A end={a_first['metrics']['end']}, B start={b_first['metrics']['start']}" + ) + + # --- 8. Tool spans fit within their parent LLM span --- + for tool in non_agent_tools: + parent_llm = next((s for s in llm_spans if s["span_id"] == tool["span_parents"][0]), None) + if parent_llm and "end" in parent_llm.get("metrics", {}): + assert tool["metrics"]["start"] >= parent_llm["metrics"]["start"], "Tool starts before parent LLM" + assert tool["metrics"]["end"] <= parent_llm["metrics"]["end"], "Tool extends past parent LLM" + + +@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") +@pytest.mark.asyncio +async def test_interleaved_subagent_tool_spans_preserve_output(memory_logger): + """Cassette-backed test: tool spans from one subagent must retain their + output when another subagent's AssistantMessage arrives before the first + subagent's ToolResultBlock. + + The cassette replays a realistic SDK message stream where: + 1. Orchestrator launches subagent-alpha and subagent-beta + 2. Alpha's LLM turn emits a Bash tool call + 3. Beta's LLM turn emits a Read tool call BEFORE alpha's tool result + 4. Alpha's tool result arrives + 5. 
Beta's tool result arrives + + Expected: Both Bash and Read tool spans should have their output recorded. + Bug: cleanup() in receive_response force-ends alpha's Bash tool span when + beta's AssistantMessage arrives, so alpha's ToolResultBlock is silently + skipped and its output is lost. + """ + assert not memory_logger.pop() + + with _patched_claude_sdk(wrap_client=True): + options = claude_agent_sdk.ClaudeAgentOptions( + model=TEST_MODEL, + permission_mode="bypassPermissions", + ) + transport = make_cassette_transport( + cassette_name="test_interleaved_subagent_tool_output_preserved", + prompt="", + options=options, + ) + + async with claude_agent_sdk.ClaudeSDKClient(options=options, transport=transport) as client: + await client.query("Launch two subagents to process files.") + async for message in client.receive_response(): + if type(message).__name__ == "ResultMessage": + break + + spans = memory_logger.pop() + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + bash_span = _find_span_by_name(tool_spans, "Bash") + read_span = _find_span_by_name(tool_spans, "Read") + + # Both tool spans should have their output recorded + assert bash_span.get("output") is not None, ( + "Bash tool span output was lost — the cleanup force-ended it before its ToolResultBlock arrived" + ) + assert bash_span["output"]["content"] == "alpha_file_contents" + + assert read_span.get("output") is not None, ( + "Read tool span output was lost — the cleanup force-ended it before its ToolResultBlock arrived" + ) + assert read_span["output"]["content"] == "beta_file_contents" + + +@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed") +@pytest.mark.asyncio +async def test_interleaved_subagent_tool_spans_parent_to_correct_llm(memory_logger): + """Cassette-backed test: tool spans from interleaved subagents must be + parented to the LLM span from their own subagent, not the most recent + LLM span from any subagent. 
+ + Uses the same interleaved cassette to verify that even when messages from + different subagents interleave on the single message stream, each tool span + references the correct LLM parent via parent_tool_use_id routing. + """ + assert not memory_logger.pop() + + with _patched_claude_sdk(wrap_client=True): + options = claude_agent_sdk.ClaudeAgentOptions( + model=TEST_MODEL, + permission_mode="bypassPermissions", + ) + transport = make_cassette_transport( + cassette_name=_sdk_cassette_name( + "test_interleaved_subagent_tool_output_preserved", + min_version="0.1.11", + ), + prompt="", + options=options, + ) + + async with claude_agent_sdk.ClaudeSDKClient(options=options, transport=transport) as client: + await client.query("Launch two subagents to process files.") + async for message in client.receive_response(): + if type(message).__name__ == "ResultMessage": + break + + spans = memory_logger.pop() + llm_spans = _find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = _find_spans_by_type(spans, SpanTypeAttribute.TOOL) + task_spans = _find_spans_by_type(spans, SpanTypeAttribute.TASK) + + _find_span_by_name(task_spans, "Claude Agent") + + if not _sdk_version_at_least("0.1.11"): + # SDK 0.1.10 replays a limited cassette; only assert root task span. 
+        return
+
+    alpha_task = _find_span_by_name(task_spans, "Process alpha file")
+    beta_task = _find_span_by_name(task_spans, "Process beta file")
+
+    bash_span = _find_span_by_name(tool_spans, "Bash")
+    read_span = _find_span_by_name(tool_spans, "Read")
+
+    # Find each tool's parent LLM span
+    bash_parent_llm_id = bash_span["span_parents"][0]
+    read_parent_llm_id = read_span["span_parents"][0]
+
+    bash_parent_llm = next(s for s in llm_spans if s["span_id"] == bash_parent_llm_id)
+    read_parent_llm = next(s for s in llm_spans if s["span_id"] == read_parent_llm_id)
+
+    # Bash's parent LLM should be under alpha's task
+    assert alpha_task["span_id"] in bash_parent_llm["span_parents"], (
+        f"Bash's parent LLM span should be under alpha task, but its parents are {bash_parent_llm['span_parents']}"
+    )
+
+    # Read's parent LLM should be under beta's task
+    assert beta_task["span_id"] in read_parent_llm["span_parents"], (
+        f"Read's parent LLM span should be under beta task, but its parents are {read_parent_llm['span_parents']}"
+    )
+
+    # The two tool spans should have DIFFERENT LLM parents (not shared)
+    assert bash_parent_llm_id != read_parent_llm_id, (
+        "Tool spans from different subagents should be parented to different LLM spans"
+    )
+
+
+@pytest.mark.asyncio
+async def test_concurrent_subagent_tool_output_not_silently_dropped(memory_logger):
+    """cleanup_context() scoped to a different subagent must not end tool spans
+    from the first subagent. When cleanup_context() targets beta's context,
+    alpha's Bash tool span must survive so its ToolResultBlock is recorded.
+    """
+    assert not memory_logger.pop()
+
+    tracker = ToolSpanTracker()
+
+    with start_span(name="Claude Agent", type=SpanTypeAttribute.TASK) as task_span:
+        # Alpha's LLM span and Bash tool span (parent_tool_use_id="call-alpha")
+        llm_span = start_span(
+            name="anthropic.messages.create",
+            type=SpanTypeAttribute.LLM,
+            parent=task_span.export(),
+        )
+        tracker.start_tool_spans(
+            AssistantMessage(
+                content=[ToolUseBlock(id="bash-1", name="Bash", input={"command": "echo hello"})],
+                parent_tool_use_id="call-alpha",
+            ),
+            llm_span.export(),
+        )
+
+        assert tracker.has_active_spans, "Tool span should be active after start_tool_spans"
+
+        # Cleanup triggered by beta's AssistantMessage — scoped to beta's context
+        tracker.cleanup_context("call-beta")
+
+        # Alpha's tool span should still be active
+        assert tracker.has_active_spans, "cleanup_context('call-beta') should not end alpha's tool span"
+
+        # Alpha's ToolResultBlock arrives and should be recorded
+        tracker.finish_tool_spans(
+            UserMessage(content=[ToolResultBlock(tool_use_id="bash-1", content=[TextBlock("hello")])])
+        )
+        llm_span.end()
+
+    spans = memory_logger.pop()
+    bash_span = _find_span_by_name(
+        [s for s in spans if s.get("span_attributes", {}).get("type") == SpanTypeAttribute.TOOL],
+        "Bash",
+    )
+
+    assert bash_span.get("output") is not None, (
+        "Tool result was silently dropped. cleanup() scoped to a different subagent "
+        "should not have ended this tool span."
+    )
+    assert bash_span["output"]["content"] == "hello"
+
+
+def test_tool_span_tracker_cleanup_preserves_cross_subagent_spans(memory_logger):
+    """cleanup_context(...) should not end tool spans that
+    belong to a different subagent context.
+
+    Alpha starts a Bash tool span. A cleanup scoped to beta's context fires.
+    Alpha's span must survive so its ToolResultBlock is recorded.
+ """ + assert not memory_logger.pop() + + tracker = ToolSpanTracker() + + with start_span(name="Claude Agent", type=SpanTypeAttribute.TASK) as task_span: + # Alpha's LLM span and tool span + alpha_llm = start_span( + name="anthropic.messages.create", + type=SpanTypeAttribute.LLM, + parent=task_span.export(), + ) + tracker.start_tool_spans( + AssistantMessage( + content=[ToolUseBlock(id="bash-alpha", name="Bash", input={"command": "echo alpha"})], + parent_tool_use_id="call-alpha", + ), + alpha_llm.export(), + ) + + # Cleanup triggered by beta's AssistantMessage — scoped to beta + tracker.cleanup_context("call-beta") + + # Alpha's span should still be active + assert tracker.has_active_spans, "Alpha's tool span should survive beta-scoped cleanup" + + # Alpha's tool result arrives + tracker.finish_tool_spans( + UserMessage(content=[ToolResultBlock(tool_use_id="bash-alpha", content=[TextBlock("alpha output")])]) + ) + alpha_llm.end() + + spans = memory_logger.pop() + bash_spans = [s for s in spans if s.get("span_attributes", {}).get("name") == "Bash"] + assert len(bash_spans) == 1 + bash_span = bash_spans[0] + + assert bash_span.get("output") is not None, ( + "Tool span output was lost because cleanup() ended a span from a different subagent context." + ) + assert bash_span["output"]["content"] == "alpha output" + + +@pytest.mark.asyncio +async def test_identical_concurrent_tool_calls_from_sibling_subagents_disambiguated(memory_logger): + """When two sibling subagents invoke the same tool with the same args, + each handler must acquire the tool span belonging to its own subagent + (matched by FIFO dispatch order) rather than stealing the other's span. 
+ """ + assert not memory_logger.pop() + + wrapped_tool_class = _create_tool_wrapper_class(_make_fake_sdk_mcp_tool_class()) + + async def echo_handler(args): + nested = start_span(name=f"nested_{args['_tag']}") + nested.log(input=args) + nested.end() + return {"content": [{"type": "text", "text": args["_tag"]}]} + + echo_tool = wrapped_tool_class( + name="echo", + description="Echo a message", + input_schema={"type": "object"}, + handler=echo_handler, + ) + + tracker = ToolSpanTracker() + shared_input = {"message": "hello", "_tag": "alpha"} + + with start_span(name="Claude Agent", type=SpanTypeAttribute.TASK) as task_span: + # Subagent alpha's LLM span and tool span + alpha_llm = start_span( + name="anthropic.messages.create", + type=SpanTypeAttribute.LLM, + parent=task_span.export(), + ) + tracker.start_tool_spans( + AssistantMessage( + content=[ToolUseBlock(id="echo-alpha", name="echo", input=shared_input)], + parent_tool_use_id="call-alpha", + ), + alpha_llm.export(), + ) + + # Subagent beta's LLM span and tool span — same tool, same input + beta_llm = start_span( + name="anthropic.messages.create", + type=SpanTypeAttribute.LLM, + parent=task_span.export(), + ) + tracker.start_tool_spans( + AssistantMessage( + content=[ToolUseBlock(id="echo-beta", name="echo", input=shared_input)], + parent_tool_use_id="call-beta", + ), + beta_llm.export(), + ) + + _thread_local.tool_span_tracker = tracker + try: + # Handler for alpha fires first (FIFO order matches creation order) + await echo_tool.handler(shared_input) + # Handler for beta fires second + await echo_tool.handler(shared_input) + + tracker.finish_tool_spans( + UserMessage( + content=[ToolResultBlock(tool_use_id="echo-alpha", content=[TextBlock("alpha")])], + parent_tool_use_id="call-alpha", + ) + ) + tracker.finish_tool_spans( + UserMessage( + content=[ToolResultBlock(tool_use_id="echo-beta", content=[TextBlock("beta")])], + parent_tool_use_id="call-beta", + ) + ) + finally: + _clear_tool_span_tracker() + 
tracker.cleanup_all() + alpha_llm.end() + beta_llm.end() + + spans = memory_logger.pop() + echo_spans = [ + s for s in _find_spans_by_type(spans, SpanTypeAttribute.TOOL) if s["span_attributes"]["name"] == "echo" + ] + assert len(echo_spans) == 2, f"Expected 2 echo tool spans, got {len(echo_spans)}" + + # Identify which span belongs to alpha's and beta's tool call + alpha_echo = [s for s in echo_spans if s.get("metadata", {}).get("gen_ai.tool.call.id") == "echo-alpha"] + beta_echo = [s for s in echo_spans if s.get("metadata", {}).get("gen_ai.tool.call.id") == "echo-beta"] + assert len(alpha_echo) == 1, "Should have exactly one alpha echo span" + assert len(beta_echo) == 1, "Should have exactly one beta echo span" + + # Both handlers receive the same input with _tag="alpha", so both nested + # spans are named "nested_alpha". Find both by filtering. + nested_spans = [s for s in spans if s["span_attributes"]["name"] == "nested_alpha"] + assert len(nested_spans) == 2, f"Expected 2 nested spans, got {len(nested_spans)}" + + # The first handler invocation should nest under the first span (alpha), + # and the second under the second span (beta). + first_nested = nested_spans[0] + assert alpha_echo[0]["span_id"] in first_nested["span_parents"], ( + "First handler's nested span should be parented under alpha's echo tool span, not swapped with beta's." + ) + second_nested = nested_spans[1] + assert beta_echo[0]["span_id"] in second_nested["span_parents"], ( + "Second handler's nested span should be parented under beta's echo tool span, not swapped with alpha's." + ) + + +def test_dispatch_queue_assigns_identical_tool_spans_in_fifo_order(memory_logger): + """ToolSpanTracker.acquire_span_for_handler() should use the dispatch queue + to assign identical (same name + same input) tool spans in FIFO order, + preventing span swaps between sibling subagents. 
+ """ + assert not memory_logger.pop() + + tracker = ToolSpanTracker() + shared_input = {"cmd": "echo hi"} + + with start_span(name="Claude Agent", type=SpanTypeAttribute.TASK) as task_span: + llm_alpha = start_span( + name="anthropic.messages.create", + type=SpanTypeAttribute.LLM, + parent=task_span.export(), + ) + tracker.start_tool_spans( + AssistantMessage( + content=[ToolUseBlock(id="bash-A", name="Bash", input=shared_input)], + parent_tool_use_id="call-alpha", + ), + llm_alpha.export(), + ) + + llm_beta = start_span( + name="anthropic.messages.create", + type=SpanTypeAttribute.LLM, + parent=task_span.export(), + ) + tracker.start_tool_spans( + AssistantMessage( + content=[ToolUseBlock(id="bash-B", name="Bash", input=shared_input)], + parent_tool_use_id="call-beta", + ), + llm_beta.export(), + ) + + # First acquire should return alpha's span (FIFO) + first = tracker.acquire_span_for_handler("Bash", shared_input) + assert first is not None + assert first.tool_use_id == "bash-A", ( + f"First acquire should return alpha's span (bash-A), got {first.tool_use_id}" + ) + + # Second acquire should return beta's span + second = tracker.acquire_span_for_handler("Bash", shared_input) + assert second is not None + assert second.tool_use_id == "bash-B", ( + f"Second acquire should return beta's span (bash-B), got {second.tool_use_id}" + ) + + # Cleanup + first.release() + second.release() + tracker.cleanup_all() + llm_alpha.end() + llm_beta.end() + + memory_logger.pop() # consume spans