diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index f75a71571..25b87022e 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -113,6 +113,23 @@ func NewRemoteToolset(name, urlString, transport string, headers map[string]stri // because there is no live connection to monitor. var errServerUnavailable = errors.New("MCP server unavailable") +const ( + // mcpInitTimeout is the maximum time allowed for the MCP initialization + // handshake (connect + initialize). If a remote server accepts the TCP + // connection but never responds, this prevents the agent from hanging + // indefinitely. + mcpInitTimeout = 2 * time.Minute + + // mcpCallToolTimeout is the maximum time allowed for a single tool call. + // Tool calls may be long-running (e.g. code execution), so this is + // deliberately generous. + mcpCallToolTimeout = 10 * time.Minute + + // mcpListTimeout is the maximum time allowed for listing tools, prompts + // or fetching a single prompt from the MCP server. + mcpListTimeout = 1 * time.Minute +) + // Describe returns a short, user-visible description of this toolset instance. // It never includes secrets. func (ts *Toolset) Describe() string { @@ -176,6 +193,11 @@ func (ts *Toolset) doStart(ctx context.Context) error { // This is critical for OAuth flows where the toolset connection needs to remain alive after the initial HTTP request completes. ctx = context.WithoutCancel(ctx) + // Apply an initialization timeout so we don't hang forever if the + // remote server accepts the connection but never responds. + initCtx, cancel := context.WithTimeout(ctx, mcpInitTimeout) + defer cancel() + slog.Debug("Starting MCP toolset", "server", ts.logID) // Register notification handlers to invalidate caches when the server @@ -219,7 +241,7 @@ func (ts *Toolset) doStart(ctx context.Context) error { const maxRetries = 3 for attempt := 0; ; attempt++ { var err error - result, err = ts.mcpClient.Initialize(ctx, initRequest) + result, err = ts.mcpClient.Initialize(initCtx, initRequest) if err == nil { break } @@ -247,8 +269,8 @@ func (ts *Toolset) doStart(ctx context.Context) error { slog.Debug("MCP initialize failed to send initialized notification; retrying", "id", ts.logID, "attempt", attempt+1, "backoff_ms", backoff.Milliseconds()) select { case <-time.After(backoff): - case <-ctx.Done(): - return fmt.Errorf("failed to initialize MCP client: %w", ctx.Err()) + case <-initCtx.Done(): + return fmt.Errorf("failed to initialize MCP client: %w", initCtx.Err()) } } diff --git a/pkg/tools/mcp/session_client.go b/pkg/tools/mcp/session_client.go index e104a35d0..16bbf767e 100644 --- a/pkg/tools/mcp/session_client.go +++ b/pkg/tools/mcp/session_client.go @@ -94,7 +94,16 @@ func (c *sessionClient) Close(context.Context) error { func (c *sessionClient) ListTools(ctx context.Context, request *gomcp.ListToolsParams) iter.Seq2[*gomcp.Tool, error] { if s := c.getSession(); s != nil { - return s.Tools(ctx, request) + ctx, cancel := context.WithTimeout(ctx, mcpListTimeout) + // Wrap the iterator so the cancel fires after iteration completes. + return func(yield func(*gomcp.Tool, error) bool) { + defer cancel() + for t, err := range s.Tools(ctx, request) { + if !yield(t, err) { + return + } + } + } } return func(yield func(*gomcp.Tool, error) bool) { yield(nil, errors.New("session not initialized")) @@ -103,6 +112,8 @@ func (c *sessionClient) ListTools(ctx context.Context, request *gomcp.ListToolsP func (c *sessionClient) CallTool(ctx context.Context, request *gomcp.CallToolParams) (*gomcp.CallToolResult, error) { if s := c.getSession(); s != nil { + ctx, cancel := context.WithTimeout(ctx, mcpCallToolTimeout) + defer cancel() return s.CallTool(ctx, request) } return nil, errors.New("session not initialized") @@ -110,7 +121,15 @@ func (c *sessionClient) CallTool(ctx context.Context, request *gomcp.CallToolPar func (c *sessionClient) ListPrompts(ctx context.Context, request *gomcp.ListPromptsParams) iter.Seq2[*gomcp.Prompt, error] { if s := c.getSession(); s != nil { - return s.Prompts(ctx, request) + ctx, cancel := context.WithTimeout(ctx, mcpListTimeout) + return func(yield func(*gomcp.Prompt, error) bool) { + defer cancel() + for p, err := range s.Prompts(ctx, request) { + if !yield(p, err) { + return + } + } + } } return func(yield func(*gomcp.Prompt, error) bool) { yield(nil, errors.New("session not initialized")) @@ -119,6 +138,8 @@ func (c *sessionClient) ListPrompts(ctx context.Context, request *gomcp.ListProm func (c *sessionClient) GetPrompt(ctx context.Context, request *gomcp.GetPromptParams) (*gomcp.GetPromptResult, error) { if s := c.getSession(); s != nil { + ctx, cancel := context.WithTimeout(ctx, mcpListTimeout) + defer cancel() return s.GetPrompt(ctx, request) } return nil, errors.New("session not initialized")