Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions go/internal/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,27 @@ func (c *clientImpl) DeleteCheckpoint(userID, threadID string) error {
})
}

func (c *clientImpl) FindSessionByParentFunctionCallID(functionCallID string) (*dbpkg.Session, error) {
var event dbpkg.Event
// Use a wildcard between the key and value to handle potential whitespace (e.g. ": " vs ":")
searchPattern := fmt.Sprintf("%%\"parent_function_call_id\"%%\"%s\"%%", functionCallID)
err := c.db.Where("data LIKE ?", searchPattern).First(&event).Error
Comment on lines +497 to +498
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function_call_id parameter is directly embedded into a SQL LIKE pattern without proper sanitization. While Go's database/sql package provides protection against SQL injection for parameterized queries, special characters in the function call ID (like %, _, etc.) could cause unexpected matching behavior in the LIKE pattern. Consider escaping these special LIKE characters or validating the input format.

Copilot uses AI. Check for mistakes.
Comment on lines +496 to +498
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string pattern matching approach for JSON data is fragile and could produce false positives if the function call ID appears in other parts of the JSON (e.g., in a string value or different field). The pattern searches for any occurrence of the ID in the Data field. Consider using JSON extraction functions provided by the database (e.g., PostgreSQL's ->> operator or MySQL's JSON_EXTRACT) for more robust and accurate JSON field matching.

Suggested change
// Use a wildcard between the key and value to handle potential whitespace (e.g. ": " vs ":")
searchPattern := fmt.Sprintf("%%\"parent_function_call_id\"%%\"%s\"%%", functionCallID)
err := c.db.Where("data LIKE ?", searchPattern).First(&event).Error
// Use JSON field extraction to precisely match the parent_function_call_id value
err := c.db.Where("data->>'parent_function_call_id' = ?", functionCallID).First(&event).Error

Copilot uses AI. Check for mistakes.
Comment on lines +496 to +498
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The LIKE query pattern may have performance implications on large datasets as it performs a full table scan. Consider adding a database index on the Event.Data column if this functionality is used frequently, or consider storing parent_function_call_id in a separate indexed column for more efficient lookups.

Suggested change
// Use a wildcard between the key and value to handle potential whitespace (e.g. ": " vs ":")
searchPattern := fmt.Sprintf("%%\"parent_function_call_id\"%%\"%s\"%%", functionCallID)
err := c.db.Where("data LIKE ?", searchPattern).First(&event).Error
// Query the JSON field directly for the parent_function_call_id to avoid a LIKE scan.
err := c.db.Where("JSON_EXTRACT(data, '$.parent_function_call_id') = ?", functionCallID).First(&event).Error

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with copilot here, I think searching through JSON like this is pretty fragile. Can we instead add more keys to the events, or think about organizing the DB data differently?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I def agree ... need some guidance on the right solution I guess.

if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, fmt.Errorf("failed to find event by parent function call ID: %w", err)
}

var session dbpkg.Session
err = c.db.Where("id = ?", event.SessionID).First(&session).Error
if err != nil {
return nil, fmt.Errorf("failed to find session for event: %w", err)
}

return &session, nil
}

// CrewAI methods

// StoreCrewAIMemory stores CrewAI agent memory
Expand Down
20 changes: 20 additions & 0 deletions go/internal/database/fake/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,26 @@ func (c *InMemoryFakeClient) ListWrites(userID, threadID, checkpointNS, checkpoi
return writes[start:end], nil
}

func (c *InMemoryFakeClient) FindSessionByParentFunctionCallID(functionCallID string) (*database.Session, error) {
c.mu.RLock()
defer c.mu.RUnlock()

searchPattern1 := fmt.Sprintf(`"parent_function_call_id":"%s"`, functionCallID)
searchPattern2 := fmt.Sprintf(`"parent_function_call_id": "%s"`, functionCallID)

for _, event := range c.events {
if strings.Contains(event.Data, searchPattern1) || strings.Contains(event.Data, searchPattern2) {
for _, session := range c.sessions {
if session.ID == event.SessionID {
return session, nil
}
}
return nil, gorm.ErrRecordNotFound
Comment on lines +748 to +758
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fake client implementation uses simple string matching which has the same issue as the real implementation - it could match the function call ID anywhere in the JSON data, not just in the parent_function_call_id field. Additionally, the dual pattern matching (with and without space after colon) is a workaround that suggests the data format is inconsistent. Consider using proper JSON parsing for more reliable matching.

Suggested change
searchPattern1 := fmt.Sprintf(`"parent_function_call_id":"%s"`, functionCallID)
searchPattern2 := fmt.Sprintf(`"parent_function_call_id": "%s"`, functionCallID)
for _, event := range c.events {
if strings.Contains(event.Data, searchPattern1) || strings.Contains(event.Data, searchPattern2) {
for _, session := range c.sessions {
if session.ID == event.SessionID {
return session, nil
}
}
return nil, gorm.ErrRecordNotFound
for _, event := range c.events {
var payload map[string]interface{}
if err := json.Unmarshal([]byte(event.Data), &payload); err != nil {
// If the event data is not valid JSON, skip this event.
continue
}
if v, ok := payload["parent_function_call_id"]; ok {
if id, ok := v.(string); ok && id == functionCallID {
for _, session := range c.sessions {
if session.ID == event.SessionID {
return session, nil
}
}
return nil, gorm.ErrRecordNotFound
}

Copilot uses AI. Check for mistakes.
}
}
return nil, gorm.ErrRecordNotFound
}

// CrewAI methods

// StoreCrewAIMemory stores CrewAI agent memory
Expand Down
28 changes: 28 additions & 0 deletions go/internal/httpserver/handlers/sessions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handlers

import (
stdErrors "errors"
"fmt"
"net/http"
"strconv"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/kagent-dev/kagent/go/internal/utils"
"github.com/kagent-dev/kagent/go/pkg/client/api"
"github.com/kagent-dev/kagent/go/pkg/database"
"gorm.io/gorm"
ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
"trpc.group/trpc-go/trpc-a2a-go/protocol"
)
Expand Down Expand Up @@ -395,6 +397,32 @@ func (h *SessionsHandler) HandleAddEventToSession(w ErrorResponseWriter, r *http
RespondWithJSON(w, http.StatusCreated, data)
}

func (h *SessionsHandler) HandleFindSessionByFunctionCall(w ErrorResponseWriter, r *http.Request) {
log := ctrllog.FromContext(r.Context()).WithName("sessions-handler").WithValues("operation", "find-by-function-call")

functionCallID := r.URL.Query().Get("function_call_id")
if functionCallID == "" {
w.RespondWithError(errors.NewBadRequestError("function_call_id query parameter is required", nil))
return
}
log = log.WithValues("function_call_id", functionCallID)

log.V(1).Info("Finding session by parent function call ID")
session, err := h.DatabaseService.FindSessionByParentFunctionCallID(functionCallID)
if err != nil {
if stdErrors.Is(err, gorm.ErrRecordNotFound) {
w.RespondWithError(errors.NewNotFoundError("Session not found for function call ID", err))
return
}
w.RespondWithError(errors.NewInternalServerError("Failed to find session by function call ID", err))
return
}

log.Info("Successfully found session", "sessionID", session.ID)
data := api.NewResponse(session, "Successfully found session", false)
RespondWithJSON(w, http.StatusOK, data)
}

func getUserID(r *http.Request) (string, error) {
log := ctrllog.Log.WithName("http-helpers")

Expand Down
56 changes: 54 additions & 2 deletions go/internal/httpserver/handlers/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,9 @@ func TestSessionsHandler(t *testing.T) {
userID := "test-user"
sessionID := "test-session"

// Create test session
agentID := "1"
createTestSession(dbClient, sessionID, userID, agentID)

// Create events with different timestamps
event1 := &database.Event{
ID: "event-1",
SessionID: sessionID,
Expand Down Expand Up @@ -522,4 +520,58 @@ func TestSessionsHandler(t *testing.T) {
assert.NotNil(t, responseRecorder.errorReceived)
})
})

t.Run("HandleFindSessionByFunctionCall", func(t *testing.T) {
t.Run("Success", func(t *testing.T) {
handler, dbClient, responseRecorder := setupHandler()
userID := "test-user"
sessionID := "test-session"
functionCallID := "call-123"

agentID := "1"
session := createTestSession(dbClient, sessionID, userID, agentID)

event := &database.Event{
ID: "event-1",
SessionID: sessionID,
UserID: userID,
Data: `{"parent_function_call_id": "` + functionCallID + `"}`,
}
dbClient.StoreEvents(event)

req := httptest.NewRequest("GET", "/api/sessions/find-by-function-call?function_call_id="+functionCallID, nil)

handler.HandleFindSessionByFunctionCall(responseRecorder, req)

assert.Equal(t, http.StatusOK, responseRecorder.Code)

var response api.StandardResponse[*database.Session]
err := json.Unmarshal(responseRecorder.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, session.ID, response.Data.ID)
})

t.Run("SessionNotFound", func(t *testing.T) {
handler, _, responseRecorder := setupHandler()
functionCallID := "non-existent-call"

req := httptest.NewRequest("GET", "/api/sessions/find-by-function-call?function_call_id="+functionCallID, nil)

handler.HandleFindSessionByFunctionCall(responseRecorder, req)

assert.Equal(t, http.StatusNotFound, responseRecorder.Code)
assert.NotNil(t, responseRecorder.errorReceived)
})

t.Run("MissingFunctionCallID", func(t *testing.T) {
handler, _, responseRecorder := setupHandler()

req := httptest.NewRequest("GET", "/api/sessions/find-by-function-call", nil)

handler.HandleFindSessionByFunctionCall(responseRecorder, req)

assert.Equal(t, http.StatusBadRequest, responseRecorder.Code)
assert.NotNil(t, responseRecorder.errorReceived)
})
})
}
1 change: 1 addition & 0 deletions go/internal/httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func (s *HTTPServer) setupRoutes() {
s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleListSessions)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleCreateSession)).Methods(http.MethodPost)
s.router.HandleFunc(APIPathSessions+"/agent/{namespace}/{name}", adaptHandler(s.handlers.Sessions.HandleGetSessionsForAgent)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/find-by-function-call", adaptHandler(s.handlers.Sessions.HandleFindSessionByFunctionCall)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleGetSession)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/{session_id}/tasks", adaptHandler(s.handlers.Sessions.HandleListTasksForSession)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleDeleteSession)).Methods(http.MethodDelete)
Expand Down
3 changes: 3 additions & 0 deletions go/pkg/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ type Client interface {
ListEventsForSession(sessionID, userID string, options QueryOptions) ([]*Event, error)
ListPushNotifications(taskID string) ([]*protocol.TaskPushNotificationConfig, error)

// Lookup methods
FindSessionByParentFunctionCallID(functionCallID string) (*Session, error)

// Helper methods
RefreshToolsForServer(serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from typing_extensions import override

from kagent.core.a2a import TaskResultAggregator, get_kagent_metadata_key

from kagent.core.tracing._span_processor import (
clear_kagent_span_attributes,
set_kagent_span_attributes,
Expand Down Expand Up @@ -262,11 +263,14 @@ async def _handle_request(
# ensure the session exists
session = await self._prepare_session(context, run_args, runner)

# set request headers to session state
# set request headers and A2A request metadata to session state
headers = context.call_context.state.get("headers", {})
state_changes = {
state_changes: dict[str, Any] = {
"headers": headers,
}
request_metadata = context.metadata
if request_metadata:
state_changes["a2a_request_metadata"] = request_metadata

actions_with_update = EventActions(state_delta=state_changes)
system_event = Event(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Plugin that captures A2A sub-agent session metadata.

When a parent agent delegates to a remote sub-agent via AgentTool, this plugin:
1. Detects the sub-agent's context_id from A2A event metadata (on_event_callback)
2. Embeds the context_id in the tool result for historical/stored access (after_tool_callback)

This plugin is automatically propagated to child runners via AgentTool's
include_plugins=True default, so on_event_callback fires on the child runner's
events where A2A metadata is present.
"""

from __future__ import annotations

import contextvars
import logging
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.agent_tool import AgentTool

if TYPE_CHECKING:
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events.event import Event
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext

logger = logging.getLogger("kagent_adk." + __name__)


@dataclass
class _ToolCallState:
agent_tool_name: str | None = None
function_call_id: str | None = None
captured_context_id: str | None = None
captured_task_id: str | None = None


_current_tool_call: contextvars.ContextVar[_ToolCallState | None] = contextvars.ContextVar(
"_current_tool_call", default=None
)


def get_current_function_call_id() -> str | None:
"""Return the function_call_id of the currently executing AgentTool call.

This is used by the a2a_request_meta_provider to inject the parent's
function_call_id into outgoing A2A requests so the sub-agent can store it.
"""
tc = _current_tool_call.get(None)
return tc.function_call_id if tc else None


class SubAgentSessionPlugin(BasePlugin):
"""Captures A2A sub-agent session context_id and embeds it in tool results."""

def __init__(self):
super().__init__(name="sub_agent_session")

async def before_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
if isinstance(tool, AgentTool):
_current_tool_call.set(
_ToolCallState(
agent_tool_name=tool.agent.name if hasattr(tool, "agent") else tool.name,
function_call_id=tool_context.function_call_id,
)
)
return None

async def on_event_callback(
self,
*,
invocation_context: InvocationContext,
event: Event,
) -> Optional[Event]:
if not event.custom_metadata:
return None

tc = _current_tool_call.get(None)
if tc is None:
return None

context_id = event.custom_metadata.get("a2a:context_id")
task_id = event.custom_metadata.get("a2a:task_id")

if not context_id and not task_id:
return None

if context_id and not tc.captured_context_id:
tc.captured_context_id = context_id
if task_id and not tc.captured_task_id:
tc.captured_task_id = task_id

return None
Comment on lines +90 to +114
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The contextvars mechanism relies on proper async context propagation, but there's no defensive handling if the context is lost or if before_tool_callback is not called before on_event_callback. If the callbacks are invoked out of order or if the context is not properly propagated in certain async scenarios, the plugin could fail silently. Consider adding logging or defensive checks to detect and handle these edge cases.

Copilot uses AI. Check for mistakes.

async def after_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
result: dict,
) -> Optional[dict]:
if not isinstance(tool, AgentTool):
return None

tc = _current_tool_call.get(None)
if tc is None or (not tc.captured_context_id and not tc.captured_task_id):
return None

if isinstance(result, str):
result = {"result": result}

if isinstance(result, dict):
if tc.captured_context_id:
result["a2a:context_id"] = tc.captured_context_id
if tc.captured_task_id:
result["a2a:task_id"] = tc.captured_task_id
return result
Comment on lines +134 to +139
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plugin modifies the result dictionary by adding a2a metadata keys, which could potentially conflict with existing keys in the result if a tool legitimately returns data with keys named "a2a:context_id" or "a2a:task_id". While these are namespaced keys unlikely to collide, consider documenting this behavior or adding a check to warn if these keys already exist in the result.

Copilot uses AI. Check for mistakes.

return None
Loading
Loading