From 27b4b30833c2ca3ee069e59a2800b8ac52591c4e Mon Sep 17 00:00:00 2001 From: Roo Code Date: Fri, 27 Mar 2026 17:55:03 +0000 Subject: [PATCH] feat: add per-mode MCP server allowlist (allowedMcpServers) Add an optional allowedMcpServers field to ModeConfig that acts as a whitelist for MCP servers on a per-mode basis. When defined, only the listed MCP servers tools are injected for that mode. When omitted or empty, all enabled MCP servers are included (preserving current behavior). This addresses context bloat and tool limit issues when running multiple MCP servers with models that have strict tool limits (e.g. 128-tool limit). Changes: - packages/types/src/mode.ts: Add allowedMcpServers to modeConfigSchema - src/core/prompts/tools/native-tools/mcp_server.ts: Accept allowedMcpServers filter - src/core/task/build-tools.ts: Pass allowedMcpServers from mode config - src/core/prompts/system.ts: Filter MCP capabilities in system prompt Closes #12004 --- packages/types/src/__tests__/mode.test.ts | 57 ++++++++++++++++++ packages/types/src/mode.ts | 6 ++ src/core/prompts/system.ts | 9 ++- .../native-tools/__tests__/mcp_server.spec.ts | 60 +++++++++++++++++++ .../prompts/tools/native-tools/mcp_server.ts | 12 +++- src/core/task/build-tools.ts | 10 +++- 6 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 packages/types/src/__tests__/mode.test.ts diff --git a/packages/types/src/__tests__/mode.test.ts b/packages/types/src/__tests__/mode.test.ts new file mode 100644 index 00000000000..debec415be1 --- /dev/null +++ b/packages/types/src/__tests__/mode.test.ts @@ -0,0 +1,57 @@ +import { modeConfigSchema } from "../mode.js" + +describe("modeConfigSchema", () => { + const validBase = { + slug: "test-mode", + name: "Test Mode", + roleDefinition: "A test mode", + groups: ["read", "edit"], + } + + it("should accept a mode config without allowedMcpServers", () => { + const result = modeConfigSchema.safeParse(validBase) + expect(result.success).toBe(true) + }) + + it("should accept a mode config with allowedMcpServers as an array of strings", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + groups: ["read", "edit", "mcp"], + allowedMcpServers: ["postgres-mcp", "redis-mcp"], + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.allowedMcpServers).toEqual(["postgres-mcp", "redis-mcp"]) + } + }) + + it("should accept a mode config with empty allowedMcpServers array", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + allowedMcpServers: [], + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.allowedMcpServers).toEqual([]) + } + }) + + it("should reject allowedMcpServers with non-string values", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + allowedMcpServers: [123, true], + }) + expect(result.success).toBe(false) + }) + + it("should accept allowedMcpServers as undefined (backward compatible)", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + allowedMcpServers: undefined, + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.allowedMcpServers).toBeUndefined() + } + }) +}) diff --git a/packages/types/src/mode.ts b/packages/types/src/mode.ts index f981ba7bf9a..a5b95586926 100644 --- a/packages/types/src/mode.ts +++ b/packages/types/src/mode.ts @@ -102,6 +102,12 @@ export const modeConfigSchema = z.object({ customInstructions: z.string().optional(), groups: groupEntryArraySchema, source: z.enum(["global", "project"]).optional(), + /** + * Optional allowlist of MCP server names for this mode. + * When defined, only the listed MCP servers' tools are injected for the mode. + * When omitted or empty, all enabled MCP servers are included (default behavior). + */ + allowedMcpServers: z.array(z.string()).optional(), }) export type ModeConfig = z.infer diff --git a/src/core/prompts/system.ts b/src/core/prompts/system.ts index 0d6071644a9..7f88c48c4eb 100644 --- a/src/core/prompts/system.ts +++ b/src/core/prompts/system.ts @@ -66,7 +66,14 @@ async function generatePrompt( // Check if MCP functionality should be included const hasMcpGroup = modeConfig.groups.some((groupEntry) => getGroupName(groupEntry) === "mcp") - const hasMcpServers = mcpHub && mcpHub.getServers().length > 0 + let hasMcpServers = mcpHub && mcpHub.getServers().length > 0 + + // If this mode has an allowedMcpServers allowlist, check that at least one allowed server exists + if (hasMcpServers && modeConfig.allowedMcpServers && modeConfig.allowedMcpServers.length > 0) { + const allowedSet = new Set(modeConfig.allowedMcpServers) + hasMcpServers = mcpHub!.getServers().some((server) => allowedSet.has(server.name)) + } + const shouldIncludeMcp = hasMcpGroup && hasMcpServers const codeIndexManager = CodeIndexManager.getInstance(context, cwd) diff --git a/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts b/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts index ddd7caaccf4..1cd10efe5e9 100644 --- a/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts +++ b/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts @@ -170,6 +170,66 @@ describe("getMcpServerTools", () => { }) }) + describe("allowedMcpServers filtering", () => { + it("should return all server tools when allowedMcpServers is undefined", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const server3 = createMockServer("filesystem-mcp", [createMockTool("read")]) + const mockHub = createMockMcpHub([server1, server2, server3]) + + const result = getMcpServerTools(mockHub as McpHub, undefined) + + expect(result).toHaveLength(3) + }) + + it("should return all server tools when allowedMcpServers is empty array", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const mockHub = createMockMcpHub([server1, server2]) + + const result = getMcpServerTools(mockHub as McpHub, []) + + expect(result).toHaveLength(2) + }) + + it("should filter to only allowed servers", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const server3 = createMockServer("filesystem-mcp", [createMockTool("read")]) + const mockHub = createMockMcpHub([server1, server2, server3]) + + const result = getMcpServerTools(mockHub as McpHub, ["postgres-mcp", "redis-mcp"]) + + expect(result).toHaveLength(2) + const toolNames = result.map((t) => getFunction(t).name) + expect(toolNames).toContain("mcp--postgres-mcp--query") + expect(toolNames).toContain("mcp--redis-mcp--get") + expect(toolNames).not.toContain("mcp--filesystem-mcp--read") + }) + + it("should return empty array when no servers match the allowlist", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const mockHub = createMockMcpHub([server1]) + + const result = getMcpServerTools(mockHub as McpHub, ["nonexistent-server"]) + + expect(result).toEqual([]) + }) + + it("should handle allowedMcpServers with a single server", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query"), createMockTool("execute")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const mockHub = createMockMcpHub([server1, server2]) + + const result = getMcpServerTools(mockHub as McpHub, ["postgres-mcp"]) + + expect(result).toHaveLength(2) + const toolNames = result.map((t) => getFunction(t).name) + expect(toolNames).toContain("mcp--postgres-mcp--query") + expect(toolNames).toContain("mcp--postgres-mcp--execute") + }) + }) + it("should not include required field when schema has no required fields", () => { const toolWithoutRequired: McpTool = { name: "toolWithoutRequired", diff --git a/src/core/prompts/tools/native-tools/mcp_server.ts b/src/core/prompts/tools/native-tools/mcp_server.ts index 3fbd1fbcf4a..14503ce002d 100644 --- a/src/core/prompts/tools/native-tools/mcp_server.ts +++ b/src/core/prompts/tools/native-tools/mcp_server.ts @@ -9,14 +9,22 @@ import { normalizeToolSchema, type JsonSchema } from "../../../../utils/json-sch * global and project configs, project servers take priority (handled by McpHub.getServers()). * * @param mcpHub The McpHub instance containing connected servers. + * @param allowedMcpServers Optional allowlist of server names. When provided, only servers + * whose name is in the list will have their tools included. When omitted, all servers are included. * @returns An array of OpenAI.Chat.ChatCompletionTool definitions. */ -export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTool[] { +export function getMcpServerTools(mcpHub?: McpHub, allowedMcpServers?: string[]): OpenAI.Chat.ChatCompletionTool[] { if (!mcpHub) { return [] } - const servers = mcpHub.getServers() + let servers = mcpHub.getServers() + + // If an allowlist is provided and non-empty, filter to only allowed servers + if (allowedMcpServers && allowedMcpServers.length > 0) { + const allowedSet = new Set(allowedMcpServers) + servers = servers.filter((server) => allowedSet.has(server.name)) + } const tools: OpenAI.Chat.ChatCompletionTool[] = [] // Track seen tool names to prevent duplicates (e.g., when same server exists in both global and project configs) const seenToolNames = new Set() diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index c32d8f6f9b2..18f2599ba95 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -7,6 +7,8 @@ import { customToolRegistry, formatNative } from "@roo-code/core" import type { ClineProvider } from "../webview/ClineProvider" import { getRooDirectoriesForCwd } from "../../services/roo-config/index.js" +import { getModeBySlug } from "../../shared/modes" +import { defaultModeSlug } from "../../shared/modes" import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools" import { @@ -124,8 +126,12 @@ export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsO mcpHub, ) - // Filter MCP tools based on mode restrictions. - const mcpTools = getMcpServerTools(mcpHub) + // Resolve the current mode config to get per-mode settings like allowedMcpServers. + const modeSlug = mode ?? defaultModeSlug + const modeConfig = getModeBySlug(modeSlug, customModes) + + // Filter MCP tools based on mode restrictions and per-mode allowedMcpServers. + const mcpTools = getMcpServerTools(mcpHub, modeConfig?.allowedMcpServers) const filteredMcpTools = filterMcpToolsForMode(mcpTools, mode, customModes, experiments) // Add custom tools if they are available and the experiment is enabled.