From 59a4e324b92339635aaa51d3dd5652ad33a7358f Mon Sep 17 00:00:00 2001 From: Dobes Vandermeer Date: Sun, 22 Feb 2026 20:36:06 -0800 Subject: [PATCH] Add the ability to click through to sub-agent session When debugging agent behavior it is very useful to be able to drill down into the sub-agent activity and what tool calls they had - or even drill down into *their* sub-agent calls, and so on. I find this feature very useful in agent development with sub-agents and sub-sub-agents. I'm not super happy with this implementation. It does basically work, though (at least it does in my main branch with all my customizations in it) --- go/internal/database/client.go | 21 + go/internal/database/fake/client.go | 20 + go/internal/httpserver/handlers/sessions.go | 28 ++ .../httpserver/handlers/sessions_test.go | 56 ++- go/internal/httpserver/server.go | 1 + go/pkg/database/client.go | 3 + .../src/kagent/adk/_agent_executor.py | 8 +- .../kagent/adk/_sub_agent_session_plugin.py | 141 ++++++ .../packages/kagent-adk/src/kagent/adk/cli.py | 15 +- .../kagent/adk/converters/part_converter.py | 21 +- .../kagent-adk/src/kagent/adk/types.py | 12 + .../tests/unittests/test_agent_tool.py | 415 ++++++++++++++++++ ui/src/app/actions/sessions.ts | 11 + .../[name]/function-calls/[id]/page.tsx | 76 ++++ .../[namespace]/[name]/{chat => }/layout.tsx | 0 ui/src/components/chat/AgentCallDisplay.tsx | 18 +- 16 files changed, 827 insertions(+), 19 deletions(-) create mode 100644 python/packages/kagent-adk/src/kagent/adk/_sub_agent_session_plugin.py create mode 100644 python/packages/kagent-adk/tests/unittests/test_agent_tool.py create mode 100644 ui/src/app/agents/[namespace]/[name]/function-calls/[id]/page.tsx rename ui/src/app/agents/[namespace]/[name]/{chat => }/layout.tsx (100%) diff --git a/go/internal/database/client.go b/go/internal/database/client.go index 1d691baed..08b81b051 100644 --- a/go/internal/database/client.go +++ b/go/internal/database/client.go @@ -491,6 +491,27 @@ func (c *clientImpl) DeleteCheckpoint(userID, threadID string) error { }) } +func (c *clientImpl) FindSessionByParentFunctionCallID(functionCallID string) (*dbpkg.Session, error) { + var event dbpkg.Event + // Use a wildcard between the key and value to handle potential whitespace (e.g. ": " vs ":") + searchPattern := fmt.Sprintf("%%\"parent_function_call_id\"%%\"%s\"%%", functionCallID) + err := c.db.Where("data LIKE ?", searchPattern).First(&event).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, gorm.ErrRecordNotFound + } + return nil, fmt.Errorf("failed to find event by parent function call ID: %w", err) + } + + var session dbpkg.Session + err = c.db.Where("id = ?", event.SessionID).First(&session).Error + if err != nil { + return nil, fmt.Errorf("failed to find session for event: %w", err) + } + + return &session, nil +} + // CrewAI methods // StoreCrewAIMemory stores CrewAI agent memory diff --git a/go/internal/database/fake/client.go b/go/internal/database/fake/client.go index 24bd935ca..00fe6a400 100644 --- a/go/internal/database/fake/client.go +++ b/go/internal/database/fake/client.go @@ -741,6 +741,26 @@ func (c *InMemoryFakeClient) ListWrites(userID, threadID, checkpointNS, checkpoi return writes[start:end], nil } +func (c *InMemoryFakeClient) FindSessionByParentFunctionCallID(functionCallID string) (*database.Session, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + searchPattern1 := fmt.Sprintf(`"parent_function_call_id":"%s"`, functionCallID) + searchPattern2 := fmt.Sprintf(`"parent_function_call_id": "%s"`, functionCallID) + + for _, event := range c.events { + if strings.Contains(event.Data, searchPattern1) || strings.Contains(event.Data, searchPattern2) { + for _, session := range c.sessions { + if session.ID == event.SessionID { + return session, nil + } + } + return nil, gorm.ErrRecordNotFound + } + } + return nil, gorm.ErrRecordNotFound +} + // CrewAI methods // StoreCrewAIMemory stores CrewAI agent memory diff --git a/go/internal/httpserver/handlers/sessions.go b/go/internal/httpserver/handlers/sessions.go index e93b16bef..e81014129 100644 --- a/go/internal/httpserver/handlers/sessions.go +++ b/go/internal/httpserver/handlers/sessions.go @@ -1,6 +1,7 @@ package handlers import ( + stdErrors "errors" "fmt" "net/http" "strconv" @@ -10,6 +11,7 @@ import ( "github.com/kagent-dev/kagent/go/internal/utils" "github.com/kagent-dev/kagent/go/pkg/client/api" "github.com/kagent-dev/kagent/go/pkg/database" + "gorm.io/gorm" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" "trpc.group/trpc-go/trpc-a2a-go/protocol" ) @@ -395,6 +397,32 @@ func (h *SessionsHandler) HandleAddEventToSession(w ErrorResponseWriter, r *http RespondWithJSON(w, http.StatusCreated, data) } +func (h *SessionsHandler) HandleFindSessionByFunctionCall(w ErrorResponseWriter, r *http.Request) { + log := ctrllog.FromContext(r.Context()).WithName("sessions-handler").WithValues("operation", "find-by-function-call") + + functionCallID := r.URL.Query().Get("function_call_id") + if functionCallID == "" { + w.RespondWithError(errors.NewBadRequestError("function_call_id query parameter is required", nil)) + return + } + log = log.WithValues("function_call_id", functionCallID) + + log.V(1).Info("Finding session by parent function call ID") + session, err := h.DatabaseService.FindSessionByParentFunctionCallID(functionCallID) + if err != nil { + if stdErrors.Is(err, gorm.ErrRecordNotFound) { + w.RespondWithError(errors.NewNotFoundError("Session not found for function call ID", err)) + return + } + w.RespondWithError(errors.NewInternalServerError("Failed to find session by function call ID", err)) + return + } + + log.Info("Successfully found session", "sessionID", session.ID) + data := api.NewResponse(session, "Successfully found session", false) + RespondWithJSON(w, http.StatusOK, data) +} + func getUserID(r *http.Request) (string, error) { log := ctrllog.Log.WithName("http-helpers") diff --git a/go/internal/httpserver/handlers/sessions_test.go b/go/internal/httpserver/handlers/sessions_test.go index 3f81a866e..a15aa8cbb 100644 --- a/go/internal/httpserver/handlers/sessions_test.go +++ b/go/internal/httpserver/handlers/sessions_test.go @@ -269,11 +269,9 @@ func TestSessionsHandler(t *testing.T) { userID := "test-user" sessionID := "test-session" - // Create test session agentID := "1" createTestSession(dbClient, sessionID, userID, agentID) - // Create events with different timestamps event1 := &database.Event{ ID: "event-1", SessionID: sessionID, @@ -522,4 +520,58 @@ func TestSessionsHandler(t *testing.T) { assert.NotNil(t, responseRecorder.errorReceived) }) }) + + t.Run("HandleFindSessionByFunctionCall", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { + handler, dbClient, responseRecorder := setupHandler() + userID := "test-user" + sessionID := "test-session" + functionCallID := "call-123" + + agentID := "1" + session := createTestSession(dbClient, sessionID, userID, agentID) + + event := &database.Event{ + ID: "event-1", + SessionID: sessionID, + UserID: userID, + Data: `{"parent_function_call_id": "` + functionCallID + `"}`, + } + dbClient.StoreEvents(event) + + req := httptest.NewRequest("GET", "/api/sessions/find-by-function-call?function_call_id="+functionCallID, nil) + + handler.HandleFindSessionByFunctionCall(responseRecorder, req) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + var response api.StandardResponse[*database.Session] + err := json.Unmarshal(responseRecorder.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, session.ID, response.Data.ID) + }) + + t.Run("SessionNotFound", func(t *testing.T) { + handler, _, responseRecorder := setupHandler() + functionCallID := "non-existent-call" + + req := httptest.NewRequest("GET", "/api/sessions/find-by-function-call?function_call_id="+functionCallID, nil) + + handler.HandleFindSessionByFunctionCall(responseRecorder, req) + + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) + assert.NotNil(t, responseRecorder.errorReceived) + }) + + t.Run("MissingFunctionCallID", func(t *testing.T) { + handler, _, responseRecorder := setupHandler() + + req := httptest.NewRequest("GET", "/api/sessions/find-by-function-call", nil) + + handler.HandleFindSessionByFunctionCall(responseRecorder, req) + + assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) + assert.NotNil(t, responseRecorder.errorReceived) + }) + }) } diff --git a/go/internal/httpserver/server.go b/go/internal/httpserver/server.go index dbb64ae17..72192493c 100644 --- a/go/internal/httpserver/server.go +++ b/go/internal/httpserver/server.go @@ -167,6 +167,7 @@ func (s *HTTPServer) setupRoutes() { s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleListSessions)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleCreateSession)).Methods(http.MethodPost) s.router.HandleFunc(APIPathSessions+"/agent/{namespace}/{name}", adaptHandler(s.handlers.Sessions.HandleGetSessionsForAgent)).Methods(http.MethodGet) + s.router.HandleFunc(APIPathSessions+"/find-by-function-call", adaptHandler(s.handlers.Sessions.HandleFindSessionByFunctionCall)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleGetSession)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/{session_id}/tasks", adaptHandler(s.handlers.Sessions.HandleListTasksForSession)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleDeleteSession)).Methods(http.MethodDelete) diff --git a/go/pkg/database/client.go b/go/pkg/database/client.go index e2ad77e0b..ad1e62dce 100644 --- a/go/pkg/database/client.go +++ b/go/pkg/database/client.go @@ -55,6 +55,9 @@ type Client interface { ListEventsForSession(sessionID, userID string, options QueryOptions) ([]*Event, error) ListPushNotifications(taskID string) ([]*protocol.TaskPushNotificationConfig, error) + // Lookup methods + FindSessionByParentFunctionCallID(functionCallID string) (*Session, error) + // Helper methods RefreshToolsForServer(serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error diff --git a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py index 558a6833c..8136feaee 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py +++ b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py @@ -33,6 +33,7 @@ from typing_extensions import override from kagent.core.a2a import TaskResultAggregator, get_kagent_metadata_key + from kagent.core.tracing._span_processor import ( clear_kagent_span_attributes, set_kagent_span_attributes, @@ -262,11 +263,14 @@ async def _handle_request( # ensure the session exists session = await self._prepare_session(context, run_args, runner) - # set request headers to session state + # set request headers and A2A request metadata to session state headers = context.call_context.state.get("headers", {}) - state_changes = { + state_changes: dict[str, Any] = { "headers": headers, } + request_metadata = context.metadata + if request_metadata: + state_changes["a2a_request_metadata"] = request_metadata actions_with_update = EventActions(state_delta=state_changes) system_event = Event( diff --git a/python/packages/kagent-adk/src/kagent/adk/_sub_agent_session_plugin.py b/python/packages/kagent-adk/src/kagent/adk/_sub_agent_session_plugin.py new file mode 100644 index 000000000..c293fef45 --- /dev/null +++ b/python/packages/kagent-adk/src/kagent/adk/_sub_agent_session_plugin.py @@ -0,0 +1,141 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugin that captures A2A sub-agent session metadata. + +When a parent agent delegates to a remote sub-agent via AgentTool, this plugin: +1. Detects the sub-agent's context_id from A2A event metadata (on_event_callback) +2. Embeds the context_id in the tool result for historical/stored access (after_tool_callback) + +This plugin is automatically propagated to child runners via AgentTool's +include_plugins=True default, so on_event_callback fires on the child runner's +events where A2A metadata is present. +""" + +from __future__ import annotations + +import contextvars +import logging +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING + +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.agent_tool import AgentTool + +if TYPE_CHECKING: + from google.adk.agents.invocation_context import InvocationContext + from google.adk.events.event import Event + from google.adk.tools.base_tool import BaseTool + from google.adk.tools.tool_context import ToolContext + +logger = logging.getLogger("kagent_adk." + __name__) + + +@dataclass +class _ToolCallState: + agent_tool_name: str | None = None + function_call_id: str | None = None + captured_context_id: str | None = None + captured_task_id: str | None = None + + +_current_tool_call: contextvars.ContextVar[_ToolCallState | None] = contextvars.ContextVar( + "_current_tool_call", default=None +) + + +def get_current_function_call_id() -> str | None: + """Return the function_call_id of the currently executing AgentTool call. + + This is used by the a2a_request_meta_provider to inject the parent's + function_call_id into outgoing A2A requests so the sub-agent can store it. + """ + tc = _current_tool_call.get(None) + return tc.function_call_id if tc else None + + +class SubAgentSessionPlugin(BasePlugin): + """Captures A2A sub-agent session context_id and embeds it in tool results.""" + + def __init__(self): + super().__init__(name="sub_agent_session") + + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> Optional[dict]: + if isinstance(tool, AgentTool): + _current_tool_call.set( + _ToolCallState( + agent_tool_name=tool.agent.name if hasattr(tool, "agent") else tool.name, + function_call_id=tool_context.function_call_id, + ) + ) + return None + + async def on_event_callback( + self, + *, + invocation_context: InvocationContext, + event: Event, + ) -> Optional[Event]: + if not event.custom_metadata: + return None + + tc = _current_tool_call.get(None) + if tc is None: + return None + + context_id = event.custom_metadata.get("a2a:context_id") + task_id = event.custom_metadata.get("a2a:task_id") + + if not context_id and not task_id: + return None + + if context_id and not tc.captured_context_id: + tc.captured_context_id = context_id + if task_id and not tc.captured_task_id: + tc.captured_task_id = task_id + + return None + + async def after_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + result: dict, + ) -> Optional[dict]: + if not isinstance(tool, AgentTool): + return None + + tc = _current_tool_call.get(None) + if tc is None or (not tc.captured_context_id and not tc.captured_task_id): + return None + + if isinstance(result, str): + result = {"result": result} + + if isinstance(result, dict): + if tc.captured_context_id: + result["a2a:context_id"] = tc.captured_context_id + if tc.captured_task_id: + result["a2a:task_id"] = tc.captured_task_id + return result + + return None diff --git a/python/packages/kagent-adk/src/kagent/adk/cli.py b/python/packages/kagent-adk/src/kagent/adk/cli.py index ff1b1b21a..5e2143f0e 100644 --- a/python/packages/kagent-adk/src/kagent/adk/cli.py +++ b/python/packages/kagent-adk/src/kagent/adk/cli.py @@ -15,6 +15,7 @@ from kagent.core import KAgentConfig, configure_logging, configure_tracing from . import AgentConfig, KAgentApp +from ._sub_agent_session_plugin import SubAgentSessionPlugin from .skill_fetcher import fetch_skill from .tools import add_skills_tool_to_agent @@ -61,16 +62,14 @@ def static( with open(os.path.join(filepath, "agent-card.json"), "r") as f: agent_card = json.load(f) agent_card = AgentCard.model_validate(agent_card) - plugins = None + plugins = [SubAgentSessionPlugin()] sts_integration = create_sts_integration() if sts_integration: - plugins = [sts_integration] + plugins.append(sts_integration) if agent_config.model.api_key_passthrough: from ._llm_passthrough_plugin import LLMPassthroughPlugin - if plugins is None: - plugins = [] plugins.append(LLMPassthroughPlugin()) def root_agent_factory() -> BaseAgent: @@ -150,10 +149,10 @@ def run( ): app_cfg = KAgentConfig() - plugins = None + plugins = [SubAgentSessionPlugin()] sts_integration = create_sts_integration() if sts_integration: - plugins = [sts_integration] + plugins.append(sts_integration) agent_loader = AgentLoader(agents_dir=working_dir) @@ -219,10 +218,10 @@ def root_agent_factory() -> BaseAgent: async def test_agent(agent_config: AgentConfig, agent_card: AgentCard, task: str): app_cfg = KAgentConfig(url="http://fake-url.example.com", name="test-agent", namespace="kagent") - plugins = None + plugins = [SubAgentSessionPlugin()] sts_integration = create_sts_integration() if sts_integration: - plugins = [sts_integration] + plugins.append(sts_integration) def root_agent_factory() -> BaseAgent: root_agent = agent_config.to_agent(app_cfg.name, sts_integration) diff --git a/python/packages/kagent-adk/src/kagent/adk/converters/part_converter.py b/python/packages/kagent-adk/src/kagent/adk/converters/part_converter.py index 1b1fd8b8a..b144ed21e 100644 --- a/python/packages/kagent-adk/src/kagent/adk/converters/part_converter.py +++ b/python/packages/kagent-adk/src/kagent/adk/converters/part_converter.py @@ -164,14 +164,23 @@ def convert_genai_part_to_a2a_part( ) if part.function_response: + data = part.function_response.model_dump(by_alias=True, exclude_none=True) + metadata = { + get_kagent_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + } + # Extract embedded sub-agent A2A metadata (injected by SubAgentSessionPlugin) + response = data.get("response", {}) + if isinstance(response, dict): + sub_ctx_id = response.pop("a2a:context_id", None) + sub_task_id = response.pop("a2a:task_id", None) + if sub_ctx_id: + metadata["a2a:context_id"] = sub_ctx_id + if sub_task_id: + metadata["a2a:task_id"] = sub_task_id return a2a_types.Part( root=a2a_types.DataPart( - data=part.function_response.model_dump(by_alias=True, exclude_none=True), - metadata={ - get_kagent_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - }, + data=data, + metadata=metadata, ) ) diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index eb0562425..4bd9f6b0a 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from kagent.adk._mcp_toolset import KAgentMcpToolset +from kagent.adk._sub_agent_session_plugin import get_current_function_call_id from kagent.adk.models._litellm import KAgentLiteLlm from kagent.adk.sandbox_code_executer import SandboxedLocalCodeExecutor @@ -27,6 +28,9 @@ # Proxy host header used for Gateway API routing when using a proxy PROXY_HOST_HEADER = "x-kagent-host" +# Key used to propagate parent function_call_id to sub-agents via A2A request metadata +PARENT_FUNCTION_CALL_ID_KEY = "parent_function_call_id" + # Key used to store headers in session state HEADERS_STATE_KEY = "headers" @@ -84,6 +88,13 @@ def header_provider(readonly_context: Optional[ReadonlyContext]) -> dict[str, st return header_provider +def _parent_context_meta_provider(_ctx: Any, _msg: Any) -> dict[str, Any] | None: + fc_id = get_current_function_call_id() + if fc_id: + return {PARENT_FUNCTION_CALL_ID_KEY: fc_id} + return None + + def _convert_ollama_options(options: dict[str, str] | None) -> dict[str, Any]: """Convert Ollama options from string values to their correct types. @@ -333,6 +344,7 @@ async def rewrite_url_to_proxy(request: httpx.Request) -> None: agent_card=f"{remote_agent.url}{AGENT_CARD_WELL_KNOWN_PATH}", description=remote_agent.description, httpx_client=client, + a2a_request_meta_provider=_parent_context_meta_provider, ) tools.append(AgentTool(agent=remote_a2a_agent)) diff --git a/python/packages/kagent-adk/tests/unittests/test_agent_tool.py b/python/packages/kagent-adk/tests/unittests/test_agent_tool.py new file mode 100644 index 000000000..841c536b7 --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/test_agent_tool.py @@ -0,0 +1,415 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock + +import pytest +from google.adk.tools.agent_tool import AgentTool +from google.genai import types as genai_types + +from kagent.adk._sub_agent_session_plugin import SubAgentSessionPlugin +from kagent.adk.converters.part_converter import convert_genai_part_to_a2a_part +from kagent.core.a2a import ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, + A2A_DATA_PART_METADATA_TYPE_KEY, + get_kagent_metadata_key, +) + + +class TestSubAgentSessionPluginCallbacks: + def _create_mock_event(self, custom_metadata=None): + event = Mock() + event.custom_metadata = custom_metadata + return event + + def _create_mock_invocation_context(self): + ctx = Mock() + ctx.invocation_id = "inv-123" + ctx.session = Mock() + ctx.session.id = "session-123" + return ctx + + def _create_mock_tool_context(self): + context = Mock() + context.function_call_id = "call-123" + return context + + def _create_agent_tool(self, agent_name="test_agent"): + agent = Mock() + agent.name = agent_name + agent.sub_agents = [] + return AgentTool(agent=agent) + + @pytest.mark.asyncio + async def test_on_event_captures_context_id(self): + plugin = SubAgentSessionPlugin() + event = self._create_mock_event( + custom_metadata={ + "a2a:context_id": "ctx-123", + "a2a:task_id": "task-456", + } + ) + result = await plugin.on_event_callback( + invocation_context=self._create_mock_invocation_context(), + event=event, + ) + assert result is None + + @pytest.mark.asyncio + async def test_on_event_ignores_events_without_metadata(self): + plugin = SubAgentSessionPlugin() + result = await plugin.on_event_callback( + invocation_context=self._create_mock_invocation_context(), + event=self._create_mock_event(custom_metadata=None), + ) + assert result is None + + @pytest.mark.asyncio + async def test_on_event_ignores_events_without_a2a_keys(self): + plugin = SubAgentSessionPlugin() + result = await plugin.on_event_callback( + invocation_context=self._create_mock_invocation_context(), + event=self._create_mock_event(custom_metadata={"other_key": "value"}), + ) + assert result is None + + @pytest.mark.asyncio + async def test_before_tool_resets_state_for_agent_tool(self): + plugin = SubAgentSessionPlugin() + result = await plugin.before_tool_callback( + tool=self._create_agent_tool(), + tool_args={}, + tool_context=self._create_mock_tool_context(), + ) + assert result is None + + @pytest.mark.asyncio + async def test_before_tool_ignores_non_agent_tools(self): + plugin = SubAgentSessionPlugin() + result = await plugin.before_tool_callback( + tool=Mock(spec=[]), + tool_args={}, + tool_context=self._create_mock_tool_context(), + ) + assert result is None + + @pytest.mark.asyncio + async def test_after_tool_embeds_metadata_in_string_result(self): + plugin = SubAgentSessionPlugin() + tool = self._create_agent_tool() + tool_context = self._create_mock_tool_context() + + await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tool_context) + await plugin.on_event_callback( + invocation_context=self._create_mock_invocation_context(), + event=self._create_mock_event( + custom_metadata={ + "a2a:context_id": "ctx-123", + "a2a:task_id": "task-456", + } + ), + ) + result = await plugin.after_tool_callback( + tool=tool, + tool_args={}, + tool_context=tool_context, + result="agent response", + ) + + assert isinstance(result, dict) + assert result["result"] == "agent response" + assert result["a2a:context_id"] == "ctx-123" + assert result["a2a:task_id"] == "task-456" + + @pytest.mark.asyncio + async def test_after_tool_embeds_metadata_in_dict_result(self): + plugin = SubAgentSessionPlugin() + tool = self._create_agent_tool() + tool_context = self._create_mock_tool_context() + + await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tool_context) + await plugin.on_event_callback( + invocation_context=self._create_mock_invocation_context(), + event=self._create_mock_event(custom_metadata={"a2a:context_id": "ctx-789"}), + ) + result = await plugin.after_tool_callback( + tool=tool, + tool_args={}, + tool_context=tool_context, + result={"key": "value", "other": 42}, + ) + + assert isinstance(result, dict) + assert result["key"] == "value" + assert result["other"] == 42 + assert result["a2a:context_id"] == "ctx-789" + + @pytest.mark.asyncio + async def test_after_tool_returns_none_when_no_metadata(self): + plugin = SubAgentSessionPlugin() + tool = self._create_agent_tool() + tool_context = self._create_mock_tool_context() + + await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tool_context) + result = await plugin.after_tool_callback( + tool=tool, + tool_args={}, + tool_context=tool_context, + result="response", + ) + assert result is None + + @pytest.mark.asyncio + async def test_after_tool_returns_none_for_non_agent_tool(self): + plugin = SubAgentSessionPlugin() + result = await plugin.after_tool_callback( + tool=Mock(spec=[]), + tool_args={}, + tool_context=self._create_mock_tool_context(), + result="response", + ) + assert result is None + + @pytest.mark.asyncio + async def test_captures_from_first_event_only(self): + plugin = SubAgentSessionPlugin() + tool = self._create_agent_tool() + tool_context = self._create_mock_tool_context() + ctx = self._create_mock_invocation_context() + + await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tool_context) + await plugin.on_event_callback( + invocation_context=ctx, + event=self._create_mock_event( + custom_metadata={ + "a2a:context_id": "ctx-first", + "a2a:task_id": "task-first", + } + ), + ) + await plugin.on_event_callback( + invocation_context=ctx, + event=self._create_mock_event( + custom_metadata={ + "a2a:context_id": "ctx-second", + "a2a:task_id": "task-second", + } + ), + ) + result = await plugin.after_tool_callback( + tool=tool, + tool_args={}, + tool_context=tool_context, + result="response", + ) + + assert result["a2a:context_id"] == "ctx-first" + assert result["a2a:task_id"] == "task-first" + + @pytest.mark.asyncio + async def test_before_tool_resets_between_calls(self): + plugin = SubAgentSessionPlugin() + tool = self._create_agent_tool() + tool_context = self._create_mock_tool_context() + ctx = self._create_mock_invocation_context() + + await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tool_context) + await plugin.on_event_callback( + invocation_context=ctx, + event=self._create_mock_event(custom_metadata={"a2a:context_id": "ctx-first"}), + ) + result1 = await plugin.after_tool_callback( + tool=tool, + tool_args={}, + tool_context=tool_context, + result="r1", + ) + assert result1["a2a:context_id"] == "ctx-first" + + await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tool_context) + result2 = await plugin.after_tool_callback( + tool=tool, + tool_args={}, + tool_context=tool_context, + result="r2", + ) + assert result2 is None + + +class TestPartConverterMetadataExtraction: + """Test cases for convert_genai_part_to_a2a_part() metadata extraction.""" + + def test_metadata_extracted_from_function_response(self): + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={ + "result": "ok", + "a2a:context_id": "ctx-123", + "a2a:task_id": "task-456", + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + assert a2a_part.root.metadata is not None + assert a2a_part.root.metadata["a2a:context_id"] == "ctx-123" + assert a2a_part.root.metadata["a2a:task_id"] == "task-456" + + assert "a2a:context_id" not in a2a_part.root.data["response"] + assert "a2a:task_id" not in a2a_part.root.data["response"] + assert a2a_part.root.data["response"]["result"] == "ok" + + def test_metadata_only_context_id_extracted(self): + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={ + "result": "ok", + "a2a:context_id": "ctx-only", + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + assert a2a_part.root.metadata["a2a:context_id"] == "ctx-only" + assert "a2a:task_id" not in a2a_part.root.metadata + assert "a2a:context_id" not in a2a_part.root.data["response"] + + def test_metadata_only_task_id_extracted(self): + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={ + "result": "ok", + "a2a:task_id": "task-only", + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + assert a2a_part.root.metadata["a2a:task_id"] == "task-only" + assert "a2a:context_id" not in a2a_part.root.metadata + assert "a2a:task_id" not in a2a_part.root.data["response"] + + def test_no_metadata_when_not_present(self): + """Test that DataPart metadata only has type key when no embedded metadata.""" + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={"result": "ok"}, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + assert a2a_part.root.metadata is not None + # Should only have the type key + assert get_kagent_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) in a2a_part.root.metadata + assert ( + a2a_part.root.metadata[get_kagent_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + assert "a2a:context_id" not in a2a_part.root.metadata + assert "a2a:task_id" not in a2a_part.root.metadata + + def test_metadata_keys_cleaned_from_response_dict(self): + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={ + "result": "ok", + "other_key": "other_value", + "a2a:context_id": "ctx-123", + "a2a:task_id": "task-456", + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + response_data = a2a_part.root.data["response"] + + assert "a2a:context_id" not in response_data + assert "a2a:task_id" not in response_data + assert response_data["result"] == "ok" + assert response_data["other_key"] == "other_value" + + def test_metadata_extraction_with_nested_response_dict(self): + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={ + "nested": { + "result": "ok", + }, + "a2a:context_id": "ctx-nested", + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + assert a2a_part.root.data["response"]["nested"]["result"] == "ok" + assert "a2a:context_id" not in a2a_part.root.data["response"] + assert a2a_part.root.metadata["a2a:context_id"] == "ctx-nested" + + def test_metadata_extraction_with_empty_response_dict(self): + """Test that extraction handles empty response dict.""" + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={}, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + assert a2a_part.root.data["response"] == {} + assert get_kagent_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) in a2a_part.root.metadata + assert "a2a:context_id" not in a2a_part.root.metadata + assert "a2a:task_id" not in a2a_part.root.metadata + + def test_metadata_extraction_preserves_other_metadata(self): + function_response = genai_types.FunctionResponse( + name="test_tool", + id="call-1", + response={ + "result": "ok", + "a2a:context_id": "ctx-123", + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_part = convert_genai_part_to_a2a_part(part) + + assert a2a_part is not None + metadata = a2a_part.root.metadata + + assert get_kagent_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) in metadata + assert ( + metadata[get_kagent_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + assert metadata["a2a:context_id"] == "ctx-123" diff --git a/ui/src/app/actions/sessions.ts b/ui/src/app/actions/sessions.ts index 22efe8792..b7c5cb705 100644 --- a/ui/src/app/actions/sessions.ts +++ b/ui/src/app/actions/sessions.ts @@ -109,6 +109,17 @@ export async function checkSessionExists(sessionId: string): Promise> { + try { + const data = await fetchApi>( + `/sessions/find-by-function-call?function_call_id=${encodeURIComponent(functionCallId)}` + ); + return { message: "Session found successfully", data: data.data }; + } catch (error) { + return createErrorResponse(error, "Error finding session by function call ID"); + } +} + /** * Updates a session * @param session The session to update diff --git a/ui/src/app/agents/[namespace]/[name]/function-calls/[id]/page.tsx b/ui/src/app/agents/[namespace]/[name]/function-calls/[id]/page.tsx new file mode 100644 index 000000000..6abddecd9 --- /dev/null +++ b/ui/src/app/agents/[namespace]/[name]/function-calls/[id]/page.tsx @@ -0,0 +1,76 @@ +"use client"; + +import { use, useEffect, useState } from "react"; +import ChatInterface from "@/components/chat/ChatInterface"; +import { findSessionByFunctionCallId } from "@/app/actions/sessions"; +import { Loader2 } from "lucide-react"; +import Link from "next/link"; + +export default function FunctionCallPage({ + params, +}: { + params: Promise<{ namespace: string; name: string; id: string }>; +}) { + const { namespace, name, id } = use(params); + const [sessionId, setSessionId] = useState(null); + const [notFound, setNotFound] = useState(false); + const [loading, setLoading] = useState(true); + + useEffect(() => { + let cancelled = false; + async function lookup() { + const result = await findSessionByFunctionCallId(id); + if (cancelled) return; + if (result.data?.id) { + setSessionId(result.data.id); + } else { + setNotFound(true); + } + setLoading(false); + } + lookup(); + return () => { cancelled = true; }; + }, [id]); + + if (loading) { + return ( +
+ +

Looking up sub-agent session…

+
+ ); + } + + if (notFound) { + return ( +
+

Sub-agent session not available yet

+

+ The sub-agent may still be starting up. Try refreshing. +

+
+ + Retry + + + Back to chat + +
+
+ ); + } + + return ( + + ); +} diff --git a/ui/src/app/agents/[namespace]/[name]/chat/layout.tsx b/ui/src/app/agents/[namespace]/[name]/layout.tsx similarity index 100% rename from ui/src/app/agents/[namespace]/[name]/chat/layout.tsx rename to ui/src/app/agents/[namespace]/[name]/layout.tsx diff --git a/ui/src/components/chat/AgentCallDisplay.tsx b/ui/src/components/chat/AgentCallDisplay.tsx index 6ce80d14c..d734353fc 100644 --- a/ui/src/components/chat/AgentCallDisplay.tsx +++ b/ui/src/components/chat/AgentCallDisplay.tsx @@ -1,4 +1,5 @@ import { useMemo, useState } from "react"; +import Link from "next/link"; import { FunctionCall } from "@/types"; import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card"; import { convertToUserFriendlyName } from "@/lib/utils"; @@ -17,6 +18,8 @@ interface AgentCallDisplayProps { isError?: boolean; } +const AGENT_TOOL_NAME_RE = /^(.+)__NS__(.+)$/; + const AgentCallDisplay = ({ call, result, status = "requested", isError = false }: AgentCallDisplayProps) => { const [areInputsExpanded, setAreInputsExpanded] = useState(false); const [areResultsExpanded, setAreResultsExpanded] = useState(false); @@ -24,6 +27,11 @@ const AgentCallDisplay = ({ call, result, status = "requested", isError = false const agentDisplay = useMemo(() => convertToUserFriendlyName(call.name), [call.name]); const hasResult = result !== undefined; + const agentMatch = call.name.match(AGENT_TOOL_NAME_RE); + const functionCallLink = agentMatch + ? `/agents/${agentMatch[1].replace(/_/g, "-")}/${agentMatch[2].replace(/_/g, "-")}/function-calls/${call.id}` + : null; + const getStatusDisplay = () => { if (isError && status === "executing") { return ( @@ -76,7 +84,15 @@ const AgentCallDisplay = ({ call, result, status = "requested", isError = false {agentDisplay} -
{call.id}
+
+ {functionCallLink ? ( + + {call.id} + + ) : ( + call.id + )} +
{getStatusDisplay()}