diff --git a/packages/types/src/__tests__/mode.test.ts b/packages/types/src/__tests__/mode.test.ts new file mode 100644 index 00000000000..debec415be1 --- /dev/null +++ b/packages/types/src/__tests__/mode.test.ts @@ -0,0 +1,57 @@ +import { modeConfigSchema } from "../mode.js" + +describe("modeConfigSchema", () => { + const validBase = { + slug: "test-mode", + name: "Test Mode", + roleDefinition: "A test mode", + groups: ["read", "edit"], + } + + it("should accept a mode config without allowedMcpServers", () => { + const result = modeConfigSchema.safeParse(validBase) + expect(result.success).toBe(true) + }) + + it("should accept a mode config with allowedMcpServers as an array of strings", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + groups: ["read", "edit", "mcp"], + allowedMcpServers: ["postgres-mcp", "redis-mcp"], + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.allowedMcpServers).toEqual(["postgres-mcp", "redis-mcp"]) + } + }) + + it("should accept a mode config with empty allowedMcpServers array", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + allowedMcpServers: [], + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.allowedMcpServers).toEqual([]) + } + }) + + it("should reject allowedMcpServers with non-string values", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + allowedMcpServers: [123, true], + }) + expect(result.success).toBe(false) + }) + + it("should accept allowedMcpServers as undefined (backward compatible)", () => { + const result = modeConfigSchema.safeParse({ + ...validBase, + allowedMcpServers: undefined, + }) + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.allowedMcpServers).toBeUndefined() + } + }) +}) diff --git a/packages/types/src/mode.ts b/packages/types/src/mode.ts index f981ba7bf9a..a5b95586926 100644 --- a/packages/types/src/mode.ts +++ b/packages/types/src/mode.ts @@ -102,6 +102,12 @@ export const modeConfigSchema = z.object({ customInstructions: z.string().optional(), groups: groupEntryArraySchema, source: z.enum(["global", "project"]).optional(), + /** + * Optional allowlist of MCP server names for this mode. + * When defined, only the listed MCP servers' tools are injected for the mode. + * When omitted or empty, all enabled MCP servers are included (default behavior). + */ + allowedMcpServers: z.array(z.string()).optional(), }) export type ModeConfig = z.infer diff --git a/src/core/prompts/system.ts b/src/core/prompts/system.ts index 0d6071644a9..7f88c48c4eb 100644 --- a/src/core/prompts/system.ts +++ b/src/core/prompts/system.ts @@ -66,7 +66,14 @@ async function generatePrompt( // Check if MCP functionality should be included const hasMcpGroup = modeConfig.groups.some((groupEntry) => getGroupName(groupEntry) === "mcp") - const hasMcpServers = mcpHub && mcpHub.getServers().length > 0 + let hasMcpServers = mcpHub && mcpHub.getServers().length > 0 + + // If this mode has an allowedMcpServers allowlist, check that at least one allowed server exists + if (hasMcpServers && modeConfig.allowedMcpServers && modeConfig.allowedMcpServers.length > 0) { + const allowedSet = new Set(modeConfig.allowedMcpServers) + hasMcpServers = mcpHub!.getServers().some((server) => allowedSet.has(server.name)) + } + const shouldIncludeMcp = hasMcpGroup && hasMcpServers const codeIndexManager = CodeIndexManager.getInstance(context, cwd) diff --git a/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts b/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts index ddd7caaccf4..1cd10efe5e9 100644 --- a/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts +++ b/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts @@ -170,6 +170,66 @@ describe("getMcpServerTools", () => { }) }) + describe("allowedMcpServers filtering", () => { + it("should return all server tools when allowedMcpServers is undefined", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const server3 = createMockServer("filesystem-mcp", [createMockTool("read")]) + const mockHub = createMockMcpHub([server1, server2, server3]) + + const result = getMcpServerTools(mockHub as McpHub, undefined) + + expect(result).toHaveLength(3) + }) + + it("should return all server tools when allowedMcpServers is empty array", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const mockHub = createMockMcpHub([server1, server2]) + + const result = getMcpServerTools(mockHub as McpHub, []) + + expect(result).toHaveLength(2) + }) + + it("should filter to only allowed servers", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const server3 = createMockServer("filesystem-mcp", [createMockTool("read")]) + const mockHub = createMockMcpHub([server1, server2, server3]) + + const result = getMcpServerTools(mockHub as McpHub, ["postgres-mcp", "redis-mcp"]) + + expect(result).toHaveLength(2) + const toolNames = result.map((t) => getFunction(t).name) + expect(toolNames).toContain("mcp--postgres-mcp--query") + expect(toolNames).toContain("mcp--redis-mcp--get") + expect(toolNames).not.toContain("mcp--filesystem-mcp--read") + }) + + it("should return empty array when no servers match the allowlist", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query")]) + const mockHub = createMockMcpHub([server1]) + + const result = getMcpServerTools(mockHub as McpHub, ["nonexistent-server"]) + + expect(result).toEqual([]) + }) + + it("should handle allowedMcpServers with a single server", () => { + const server1 = createMockServer("postgres-mcp", [createMockTool("query"), createMockTool("execute")]) + const server2 = createMockServer("redis-mcp", [createMockTool("get")]) + const mockHub = createMockMcpHub([server1, server2]) + + const result = getMcpServerTools(mockHub as McpHub, ["postgres-mcp"]) + + expect(result).toHaveLength(2) + const toolNames = result.map((t) => getFunction(t).name) + expect(toolNames).toContain("mcp--postgres-mcp--query") + expect(toolNames).toContain("mcp--postgres-mcp--execute") + }) + }) + it("should not include required field when schema has no required fields", () => { const toolWithoutRequired: McpTool = { name: "toolWithoutRequired", diff --git a/src/core/prompts/tools/native-tools/mcp_server.ts b/src/core/prompts/tools/native-tools/mcp_server.ts index 3fbd1fbcf4a..14503ce002d 100644 --- a/src/core/prompts/tools/native-tools/mcp_server.ts +++ b/src/core/prompts/tools/native-tools/mcp_server.ts @@ -9,14 +9,22 @@ import { normalizeToolSchema, type JsonSchema } from "../../../../utils/json-sch * global and project configs, project servers take priority (handled by McpHub.getServers()). * * @param mcpHub The McpHub instance containing connected servers. + * @param allowedMcpServers Optional allowlist of server names. When provided, only servers + * whose name is in the list will have their tools included. When omitted, all servers are included. * @returns An array of OpenAI.Chat.ChatCompletionTool definitions. */ -export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTool[] { +export function getMcpServerTools(mcpHub?: McpHub, allowedMcpServers?: string[]): OpenAI.Chat.ChatCompletionTool[] { if (!mcpHub) { return [] } - const servers = mcpHub.getServers() + let servers = mcpHub.getServers() + + // If an allowlist is provided and non-empty, filter to only allowed servers + if (allowedMcpServers && allowedMcpServers.length > 0) { + const allowedSet = new Set(allowedMcpServers) + servers = servers.filter((server) => allowedSet.has(server.name)) + } const tools: OpenAI.Chat.ChatCompletionTool[] = [] // Track seen tool names to prevent duplicates (e.g., when same server exists in both global and project configs) const seenToolNames = new Set() diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index c32d8f6f9b2..18f2599ba95 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -7,6 +7,8 @@ import { customToolRegistry, formatNative } from "@roo-code/core" import type { ClineProvider } from "../webview/ClineProvider" import { getRooDirectoriesForCwd } from "../../services/roo-config/index.js" +import { getModeBySlug } from "../../shared/modes" +import { defaultModeSlug } from "../../shared/modes" import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools" import { @@ -124,8 +126,12 @@ export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsO mcpHub, ) - // Filter MCP tools based on mode restrictions. - const mcpTools = getMcpServerTools(mcpHub) + // Resolve the current mode config to get per-mode settings like allowedMcpServers. + const modeSlug = mode ?? defaultModeSlug + const modeConfig = getModeBySlug(modeSlug, customModes) + + // Filter MCP tools based on mode restrictions and per-mode allowedMcpServers. + const mcpTools = getMcpServerTools(mcpHub, modeConfig?.allowedMcpServers) const filteredMcpTools = filterMcpToolsForMode(mcpTools, mode, customModes, experiments) // Add custom tools if they are available and the experiment is enabled.