Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions packages/types/src/__tests__/mode.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { modeConfigSchema } from "../mode.js"

describe("modeConfigSchema", () => {
const validBase = {
slug: "test-mode",
name: "Test Mode",
roleDefinition: "A test mode",
groups: ["read", "edit"],
}

it("should accept a mode config without allowedMcpServers", () => {
const result = modeConfigSchema.safeParse(validBase)
expect(result.success).toBe(true)
})

it("should accept a mode config with allowedMcpServers as an array of strings", () => {
const result = modeConfigSchema.safeParse({
...validBase,
groups: ["read", "edit", "mcp"],
allowedMcpServers: ["postgres-mcp", "redis-mcp"],
})
expect(result.success).toBe(true)
if (result.success) {
expect(result.data.allowedMcpServers).toEqual(["postgres-mcp", "redis-mcp"])
}
})

it("should accept a mode config with empty allowedMcpServers array", () => {
const result = modeConfigSchema.safeParse({
...validBase,
allowedMcpServers: [],
})
expect(result.success).toBe(true)
if (result.success) {
expect(result.data.allowedMcpServers).toEqual([])
}
})

it("should reject allowedMcpServers with non-string values", () => {
const result = modeConfigSchema.safeParse({
...validBase,
allowedMcpServers: [123, true],
})
expect(result.success).toBe(false)
})

it("should accept allowedMcpServers as undefined (backward compatible)", () => {
const result = modeConfigSchema.safeParse({
...validBase,
allowedMcpServers: undefined,
})
expect(result.success).toBe(true)
if (result.success) {
expect(result.data.allowedMcpServers).toBeUndefined()
}
})
})
6 changes: 6 additions & 0 deletions packages/types/src/mode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ export const modeConfigSchema = z.object({
customInstructions: z.string().optional(),
groups: groupEntryArraySchema,
source: z.enum(["global", "project"]).optional(),
/**
* Optional allowlist of MCP server names for this mode.
* When defined, only the listed MCP servers' tools are injected for the mode.
* When omitted or empty, all enabled MCP servers are included (default behavior).
*/
allowedMcpServers: z.array(z.string()).optional(),
})

export type ModeConfig = z.infer<typeof modeConfigSchema>
Expand Down
9 changes: 8 additions & 1 deletion src/core/prompts/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ async function generatePrompt(

// Check if MCP functionality should be included
const hasMcpGroup = modeConfig.groups.some((groupEntry) => getGroupName(groupEntry) === "mcp")
const hasMcpServers = mcpHub && mcpHub.getServers().length > 0
let hasMcpServers = mcpHub && mcpHub.getServers().length > 0

// If this mode has an allowedMcpServers allowlist, check that at least one allowed server exists
if (hasMcpServers && modeConfig.allowedMcpServers && modeConfig.allowedMcpServers.length > 0) {
const allowedSet = new Set(modeConfig.allowedMcpServers)
hasMcpServers = mcpHub!.getServers().some((server) => allowedSet.has(server.name))
}

const shouldIncludeMcp = hasMcpGroup && hasMcpServers

const codeIndexManager = CodeIndexManager.getInstance(context, cwd)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,66 @@ describe("getMcpServerTools", () => {
})
})

describe("allowedMcpServers filtering", () => {
it("should return all server tools when allowedMcpServers is undefined", () => {
const server1 = createMockServer("postgres-mcp", [createMockTool("query")])
const server2 = createMockServer("redis-mcp", [createMockTool("get")])
const server3 = createMockServer("filesystem-mcp", [createMockTool("read")])
const mockHub = createMockMcpHub([server1, server2, server3])

const result = getMcpServerTools(mockHub as McpHub, undefined)

expect(result).toHaveLength(3)
})

it("should return all server tools when allowedMcpServers is empty array", () => {
const server1 = createMockServer("postgres-mcp", [createMockTool("query")])
const server2 = createMockServer("redis-mcp", [createMockTool("get")])
const mockHub = createMockMcpHub([server1, server2])

const result = getMcpServerTools(mockHub as McpHub, [])

expect(result).toHaveLength(2)
})

it("should filter to only allowed servers", () => {
const server1 = createMockServer("postgres-mcp", [createMockTool("query")])
const server2 = createMockServer("redis-mcp", [createMockTool("get")])
const server3 = createMockServer("filesystem-mcp", [createMockTool("read")])
const mockHub = createMockMcpHub([server1, server2, server3])

const result = getMcpServerTools(mockHub as McpHub, ["postgres-mcp", "redis-mcp"])

expect(result).toHaveLength(2)
const toolNames = result.map((t) => getFunction(t).name)
expect(toolNames).toContain("mcp--postgres-mcp--query")
expect(toolNames).toContain("mcp--redis-mcp--get")
expect(toolNames).not.toContain("mcp--filesystem-mcp--read")
})

it("should return empty array when no servers match the allowlist", () => {
const server1 = createMockServer("postgres-mcp", [createMockTool("query")])
const mockHub = createMockMcpHub([server1])

const result = getMcpServerTools(mockHub as McpHub, ["nonexistent-server"])

expect(result).toEqual([])
})

it("should handle allowedMcpServers with a single server", () => {
const server1 = createMockServer("postgres-mcp", [createMockTool("query"), createMockTool("execute")])
const server2 = createMockServer("redis-mcp", [createMockTool("get")])
const mockHub = createMockMcpHub([server1, server2])

const result = getMcpServerTools(mockHub as McpHub, ["postgres-mcp"])

expect(result).toHaveLength(2)
const toolNames = result.map((t) => getFunction(t).name)
expect(toolNames).toContain("mcp--postgres-mcp--query")
expect(toolNames).toContain("mcp--postgres-mcp--execute")
})
})

it("should not include required field when schema has no required fields", () => {
const toolWithoutRequired: McpTool = {
name: "toolWithoutRequired",
Expand Down
12 changes: 10 additions & 2 deletions src/core/prompts/tools/native-tools/mcp_server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,22 @@ import { normalizeToolSchema, type JsonSchema } from "../../../../utils/json-sch
* global and project configs, project servers take priority (handled by McpHub.getServers()).
*
* @param mcpHub The McpHub instance containing connected servers.
* @param allowedMcpServers Optional allowlist of server names. When provided, only servers
* whose name is in the list will have their tools included. When omitted, all servers are included.
* @returns An array of OpenAI.Chat.ChatCompletionTool definitions.
*/
export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTool[] {
export function getMcpServerTools(mcpHub?: McpHub, allowedMcpServers?: string[]): OpenAI.Chat.ChatCompletionTool[] {
if (!mcpHub) {
return []
}

const servers = mcpHub.getServers()
let servers = mcpHub.getServers()

// If an allowlist is provided and non-empty, filter to only allowed servers
if (allowedMcpServers && allowedMcpServers.length > 0) {
const allowedSet = new Set(allowedMcpServers)
servers = servers.filter((server) => allowedSet.has(server.name))
}
const tools: OpenAI.Chat.ChatCompletionTool[] = []
// Track seen tool names to prevent duplicates (e.g., when same server exists in both global and project configs)
const seenToolNames = new Set<string>()
Expand Down
10 changes: 8 additions & 2 deletions src/core/task/build-tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { customToolRegistry, formatNative } from "@roo-code/core"

import type { ClineProvider } from "../webview/ClineProvider"
import { getRooDirectoriesForCwd } from "../../services/roo-config/index.js"
import { getModeBySlug } from "../../shared/modes"
import { defaultModeSlug } from "../../shared/modes"

import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools"
import {
Expand Down Expand Up @@ -124,8 +126,12 @@ export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsO
mcpHub,
)

// Filter MCP tools based on mode restrictions.
const mcpTools = getMcpServerTools(mcpHub)
// Resolve the current mode config to get per-mode settings like allowedMcpServers.
const modeSlug = mode ?? defaultModeSlug
const modeConfig = getModeBySlug(modeSlug, customModes)

// Filter MCP tools based on mode restrictions and per-mode allowedMcpServers.
const mcpTools = getMcpServerTools(mcpHub, modeConfig?.allowedMcpServers)
const filteredMcpTools = filterMcpToolsForMode(mcpTools, mode, customModes, experiments)

// Add custom tools if they are available and the experiment is enabled.
Expand Down
Loading