Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,23 @@ func NewRemoteToolset(name, urlString, transport string, headers map[string]stri
// because there is no live connection to monitor.
var errServerUnavailable = errors.New("MCP server unavailable")

const (
// mcpInitTimeout is the maximum time allowed for the MCP initialization
// handshake (connect + initialize). If a remote server accepts the TCP
// connection but never responds, this prevents the agent from hanging
// indefinitely.
mcpInitTimeout = 2 * time.Minute

// mcpCallToolTimeout is the maximum time allowed for a single tool call.
// Tool calls may be long-running (e.g. code execution), so this is
// deliberately generous.
mcpCallToolTimeout = 10 * time.Minute

// mcpListTimeout is the maximum time allowed for listing tools, prompts
// or fetching a single prompt from the MCP server.
mcpListTimeout = 1 * time.Minute
)

// Describe returns a short, user-visible description of this toolset instance.
// It never includes secrets.
func (ts *Toolset) Describe() string {
Expand Down Expand Up @@ -176,6 +193,11 @@ func (ts *Toolset) doStart(ctx context.Context) error {
// This is critical for OAuth flows where the toolset connection needs to remain alive after the initial HTTP request completes.
ctx = context.WithoutCancel(ctx)

// Apply an initialization timeout so we don't hang forever if the
// remote server accepts the connection but never responds.
initCtx, cancel := context.WithTimeout(ctx, mcpInitTimeout)
defer cancel()

slog.Debug("Starting MCP toolset", "server", ts.logID)

// Register notification handlers to invalidate caches when the server
Expand Down Expand Up @@ -219,7 +241,7 @@ func (ts *Toolset) doStart(ctx context.Context) error {
const maxRetries = 3
for attempt := 0; ; attempt++ {
var err error
result, err = ts.mcpClient.Initialize(ctx, initRequest)
result, err = ts.mcpClient.Initialize(initCtx, initRequest)
if err == nil {
break
}
Expand Down Expand Up @@ -247,8 +269,8 @@ func (ts *Toolset) doStart(ctx context.Context) error {
slog.Debug("MCP initialize failed to send initialized notification; retrying", "id", ts.logID, "attempt", attempt+1, "backoff_ms", backoff.Milliseconds())
select {
case <-time.After(backoff):
case <-ctx.Done():
return fmt.Errorf("failed to initialize MCP client: %w", ctx.Err())
case <-initCtx.Done():
return fmt.Errorf("failed to initialize MCP client: %w", initCtx.Err())
}
}

Expand Down
25 changes: 23 additions & 2 deletions pkg/tools/mcp/session_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,16 @@ func (c *sessionClient) Close(context.Context) error {

func (c *sessionClient) ListTools(ctx context.Context, request *gomcp.ListToolsParams) iter.Seq2[*gomcp.Tool, error] {
if s := c.getSession(); s != nil {
return s.Tools(ctx, request)
ctx, cancel := context.WithTimeout(ctx, mcpListTimeout)
// Wrap the iterator so the cancel fires after iteration completes.
return func(yield func(*gomcp.Tool, error) bool) {
defer cancel()
for t, err := range s.Tools(ctx, request) {
if !yield(t, err) {
return
}
}
}
}
return func(yield func(*gomcp.Tool, error) bool) {
yield(nil, errors.New("session not initialized"))
Expand All @@ -103,14 +112,24 @@ func (c *sessionClient) ListTools(ctx context.Context, request *gomcp.ListToolsP

func (c *sessionClient) CallTool(ctx context.Context, request *gomcp.CallToolParams) (*gomcp.CallToolResult, error) {
if s := c.getSession(); s != nil {
ctx, cancel := context.WithTimeout(ctx, mcpCallToolTimeout)
defer cancel()
return s.CallTool(ctx, request)
}
return nil, errors.New("session not initialized")
}

func (c *sessionClient) ListPrompts(ctx context.Context, request *gomcp.ListPromptsParams) iter.Seq2[*gomcp.Prompt, error] {
if s := c.getSession(); s != nil {
return s.Prompts(ctx, request)
ctx, cancel := context.WithTimeout(ctx, mcpListTimeout)
return func(yield func(*gomcp.Prompt, error) bool) {
defer cancel()
for p, err := range s.Prompts(ctx, request) {
if !yield(p, err) {
return
}
}
}
}
return func(yield func(*gomcp.Prompt, error) bool) {
yield(nil, errors.New("session not initialized"))
Expand All @@ -119,6 +138,8 @@ func (c *sessionClient) ListPrompts(ctx context.Context, request *gomcp.ListProm

func (c *sessionClient) GetPrompt(ctx context.Context, request *gomcp.GetPromptParams) (*gomcp.GetPromptResult, error) {
if s := c.getSession(); s != nil {
ctx, cancel := context.WithTimeout(ctx, mcpListTimeout)
defer cancel()
return s.GetPrompt(ctx, request)
}
return nil, errors.New("session not initialized")
Expand Down
Loading