From 73499a8f2ef7a0f32a0cf89e9aa6f973d66b976c Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 11 Jan 2026 19:36:23 +0000 Subject: [PATCH 1/5] fix: allow omitting mcpServers in session requests --- scripts/gen_schema.py | 33 +++++++-- scripts/schema_patches.py | 51 ++++++++++++++ src/acp/client/connection.py | 25 ++++++- src/acp/interfaces.py | 11 ++- src/acp/schema.py | 8 +-- tests/conftest.py | 11 ++- .../test_issue_55_mcp_servers_optional.py | 68 +++++++++++++++++++ 7 files changed, 191 insertions(+), 16 deletions(-) create mode 100644 scripts/schema_patches.py create mode 100644 tests/real_user/test_issue_55_mcp_servers_optional.py diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 74a8790..5a505a6 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -2,16 +2,23 @@ from __future__ import annotations import ast +import contextlib import json import re import subprocess import sys +import tempfile import textwrap from collections.abc import Callable from dataclasses import dataclass from pathlib import Path ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) + +from scripts.schema_patches import apply_schema_patches # noqa: E402 + SCHEMA_DIR = ROOT / "schema" SCHEMA_JSON = SCHEMA_DIR / "schema.json" VERSION_FILE = SCHEMA_DIR / "VERSION" @@ -136,12 +143,23 @@ def generate_schema() -> None: ) sys.exit(1) + schema_payload = json.loads(SCHEMA_JSON.read_text(encoding="utf-8")) + schema_payload, patch_warnings = apply_schema_patches(schema_payload) + for warning in patch_warnings: + print(f"Warning: {warning.message}", file=sys.stderr) + + patched_schema_path: Path | None = None + with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as handle: + json.dump(schema_payload, handle, indent=2) + handle.write("\n") + patched_schema_path = Path(handle.name) + cmd = [ sys.executable, "-m", "datamodel_code_generator", "--input", - str(SCHEMA_JSON), + str(patched_schema_path), "--input-file-type", "jsonschema", "--output", @@ -155,10 +173,15 @@ def generate_schema() -> None: "--snake-case-field", ] - subprocess.check_call(cmd) # noqa: S603 - warnings = postprocess_generated_schema(SCHEMA_OUT) - for warning in warnings: - print(f"Warning: {warning}", file=sys.stderr) + try: + subprocess.check_call(cmd) # noqa: S603 + warnings = postprocess_generated_schema(SCHEMA_OUT) + for warning in warnings: + print(f"Warning: {warning}", file=sys.stderr) + finally: + if patched_schema_path is not None: + with contextlib.suppress(OSError): + patched_schema_path.unlink() def postprocess_generated_schema(output_path: Path) -> list[str]: diff --git a/scripts/schema_patches.py b/scripts/schema_patches.py new file mode 100644 index 0000000..c2f828a --- /dev/null +++ b/scripts/schema_patches.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class PatchWarning: + message: str + + +def apply_schema_patches(schema: dict[str, Any]) -> tuple[dict[str, Any], list[PatchWarning]]: + patched = schema + warnings: list[PatchWarning] = [] + + patched, warning = _make_defs_field_optional(patched, "NewSessionRequest", "mcpServers") + if warning is not None: + warnings.append(warning) + + patched, warning = _make_defs_field_optional(patched, "LoadSessionRequest", "mcpServers") + if warning is not None: + warnings.append(warning) + + return patched, warnings + + +def _make_defs_field_optional( + schema: dict[str, Any], + model_name: str, + field_name: str, +) -> tuple[dict[str, Any], PatchWarning | None]: + defs = schema.get("$defs") + if not isinstance(defs, dict): + return schema, PatchWarning("schema.$defs missing or invalid; cannot apply patches") + + model = defs.get(model_name) + if not isinstance(model, dict): + return schema, PatchWarning(f"schema.$defs.{model_name} missing or invalid; cannot patch {field_name}") + + required = model.get("required") + if required is None: + return schema, None + if not isinstance(required, list): + return schema, PatchWarning(f"schema.$defs.{model_name}.required invalid; cannot patch {field_name}") + + new_required = [item for item in required if item != field_name] + if new_required == required: + return schema, None + + model["required"] = new_required + return schema, None diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index ac0d34f..dab450b 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -45,6 +45,7 @@ __all__ = ["ClientSideConnection"] _CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader" +_MISSING = object() @final @@ -93,7 +94,10 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> NewSessionResponse: return await request_model( self._conn, @@ -104,12 +108,27 @@ async def new_session( @param_model(LoadSessionRequest) async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | str | None = None, + session_id: str | object = _MISSING, + **kwargs: Any, ) -> LoadSessionResponse: + if session_id is _MISSING: + if isinstance(mcp_servers, str): + session_id = mcp_servers + mcp_servers = None + else: + raise TypeError("load_session() missing required argument: 'session_id'") return await request_model_from_dict( self._conn, AGENT_METHODS["session_load"], - LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None), + LoadSessionRequest( + cwd=cwd, + mcp_servers=cast(list[HttpMcpServer | SseMcpServer | McpServerStdio] | None, mcp_servers), + session_id=cast(str, session_id), + field_meta=kwargs or None, + ), LoadSessionResponse, ) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 457dfe7..3bf8eb4 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -154,12 +154,19 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> NewSessionResponse: ... @param_model(LoadSessionRequest) async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> LoadSessionResponse | None: ... @param_model(ListSessionsRequest) diff --git a/src/acp/schema.py b/src/acp/schema.py index e449e4a..b0247fe 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -1410,12 +1410,12 @@ class NewSessionRequest(BaseModel): ] # List of MCP (Model Context Protocol) servers the agent should connect to. mcp_servers: Annotated[ - List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]], + Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]], Field( alias="mcpServers", description="List of MCP (Model Context Protocol) servers the agent should connect to.", ), - ] + ] = None class PermissionOption(BaseModel): @@ -2073,12 +2073,12 @@ class LoadSessionRequest(BaseModel): cwd: Annotated[str, Field(description="The working directory for this session.")] # List of MCP servers to connect to for this session. mcp_servers: Annotated[ - List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]], + Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]], Field( alias="mcpServers", description="List of MCP servers to connect to for this session.", ), - ] + ] = None # The ID of the session to load. session_id: Annotated[str, Field(alias="sessionId", description="The ID of the session to load.")] diff --git a/tests/conftest.py b/tests/conftest.py index 6cce0b1..266fa71 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -243,12 +243,19 @@ async def initialize( return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[]) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> NewSessionResponse: return NewSessionResponse(session_id="test-session-123") async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> LoadSessionResponse | None: return LoadSessionResponse() diff --git a/tests/real_user/test_issue_55_mcp_servers_optional.py b/tests/real_user/test_issue_55_mcp_servers_optional.py new file mode 100644 index 0000000..52ea330 --- /dev/null +++ b/tests/real_user/test_issue_55_mcp_servers_optional.py @@ -0,0 +1,68 @@ +import asyncio +from typing import Any + +import pytest + +from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse +from acp.core import AgentSideConnection, ClientSideConnection +from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer +from tests.conftest import TestAgent, TestClient + +# Regression from a real-world client run where `mcpServers` is omitted from session requests. + + +class Issue55Agent(TestAgent): + def __init__(self) -> None: + super().__init__() + self.seen_new_session: tuple[str, Any] | None = None + self.seen_load_session: tuple[str, str, Any] | None = None + + async def new_session( + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, + ) -> NewSessionResponse: + self.seen_new_session = (cwd, mcp_servers) + return await super().new_session(cwd=cwd, mcp_servers=mcp_servers, **kwargs) + + async def load_session( + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, + ) -> LoadSessionResponse | None: + self.seen_load_session = (cwd, session_id, mcp_servers) + return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=mcp_servers, **kwargs) + + +@pytest.mark.asyncio +async def test_session_requests_allow_missing_mcp_servers(server) -> None: + client = TestClient() + captured_agent: list[Issue55Agent] = [] + + agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type] + _agent_side = AgentSideConnection( + lambda _conn: captured_agent.append(Issue55Agent()) or captured_agent[-1], + server._server_writer, + server._server_reader, + listening=True, + ) + + init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0) + assert isinstance(init, InitializeResponse) + + new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0) + assert isinstance(new_session, NewSessionResponse) + + load_session = await asyncio.wait_for( + agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id), + timeout=1.0, + ) + assert isinstance(load_session, LoadSessionResponse) + + assert captured_agent, "Agent was not constructed" + [agent] = captured_agent + assert agent.seen_new_session == ("/workspace", None) + assert agent.seen_load_session == ("/workspace", new_session.session_id, None) From 1f21da8d68fa73bc49a0c2c4bc78d3cc60063907 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 11 Jan 2026 19:40:34 +0000 Subject: [PATCH 2/5] fix: warn on positional load_session session_id --- src/acp/client/connection.py | 7 +++++++ tests/real_user/test_issue_55_mcp_servers_optional.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index dab450b..66854ea 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import warnings from collections.abc import Callable from typing import Any, cast, final @@ -116,6 +117,12 @@ async def load_session( ) -> LoadSessionResponse: if session_id is _MISSING: if isinstance(mcp_servers, str): + warnings.warn( + "Passing session_id as the second positional argument to load_session() is deprecated; " + "use load_session(cwd=..., session_id=..., mcp_servers=...) instead.", + DeprecationWarning, + stacklevel=2, + ) session_id = mcp_servers mcp_servers = None else: diff --git a/tests/real_user/test_issue_55_mcp_servers_optional.py b/tests/real_user/test_issue_55_mcp_servers_optional.py index 52ea330..230ceaa 100644 --- a/tests/real_user/test_issue_55_mcp_servers_optional.py +++ b/tests/real_user/test_issue_55_mcp_servers_optional.py @@ -62,6 +62,13 @@ async def test_session_requests_allow_missing_mcp_servers(server) -> None: ) assert isinstance(load_session, LoadSessionResponse) + with pytest.warns(DeprecationWarning): + load_session = await asyncio.wait_for( + agent_conn.load_session("/workspace", new_session.session_id), + timeout=1.0, + ) + assert isinstance(load_session, LoadSessionResponse) + assert captured_agent, "Agent was not constructed" [agent] = captured_agent assert agent.seen_new_session == ("/workspace", None) From f24134453c36a120b5ee0a9d09a9bea1b7ce92fd Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sat, 7 Feb 2026 16:14:10 +0000 Subject: [PATCH 3/5] Revert "fix: warn on positional load_session session_id" This reverts commit 1f21da8d68fa73bc49a0c2c4bc78d3cc60063907. --- src/acp/client/connection.py | 7 ------- tests/real_user/test_issue_55_mcp_servers_optional.py | 7 ------- 2 files changed, 14 deletions(-) diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index 66854ea..dab450b 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import warnings from collections.abc import Callable from typing import Any, cast, final @@ -117,12 +116,6 @@ async def load_session( ) -> LoadSessionResponse: if session_id is _MISSING: if isinstance(mcp_servers, str): - warnings.warn( - "Passing session_id as the second positional argument to load_session() is deprecated; " - "use load_session(cwd=..., session_id=..., mcp_servers=...) instead.", - DeprecationWarning, - stacklevel=2, - ) session_id = mcp_servers mcp_servers = None else: diff --git a/tests/real_user/test_issue_55_mcp_servers_optional.py b/tests/real_user/test_issue_55_mcp_servers_optional.py index 230ceaa..52ea330 100644 --- a/tests/real_user/test_issue_55_mcp_servers_optional.py +++ b/tests/real_user/test_issue_55_mcp_servers_optional.py @@ -62,13 +62,6 @@ async def test_session_requests_allow_missing_mcp_servers(server) -> None: ) assert isinstance(load_session, LoadSessionResponse) - with pytest.warns(DeprecationWarning): - load_session = await asyncio.wait_for( - agent_conn.load_session("/workspace", new_session.session_id), - timeout=1.0, - ) - assert isinstance(load_session, LoadSessionResponse) - assert captured_agent, "Agent was not constructed" [agent] = captured_agent assert agent.seen_new_session == ("/workspace", None) From 58bf3539f6eb7a97ac761bcf81e494693c4cacf7 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sat, 7 Feb 2026 16:14:12 +0000 Subject: [PATCH 4/5] Revert "fix: allow omitting mcpServers in session requests" This reverts commit 73499a8f2ef7a0f32a0cf89e9aa6f973d66b976c. --- scripts/gen_schema.py | 33 ++------- scripts/schema_patches.py | 51 -------------- src/acp/client/connection.py | 25 +------ src/acp/interfaces.py | 11 +-- src/acp/schema.py | 8 +-- tests/conftest.py | 11 +-- .../test_issue_55_mcp_servers_optional.py | 68 ------------------- 7 files changed, 16 insertions(+), 191 deletions(-) delete mode 100644 scripts/schema_patches.py delete mode 100644 tests/real_user/test_issue_55_mcp_servers_optional.py diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 5a505a6..74a8790 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -2,23 +2,16 @@ from __future__ import annotations import ast -import contextlib import json import re import subprocess import sys -import tempfile import textwrap from collections.abc import Callable from dataclasses import dataclass from pathlib import Path ROOT = Path(__file__).resolve().parents[1] -if str(ROOT) not in sys.path: - sys.path.append(str(ROOT)) - -from scripts.schema_patches import apply_schema_patches # noqa: E402 - SCHEMA_DIR = ROOT / "schema" SCHEMA_JSON = SCHEMA_DIR / "schema.json" VERSION_FILE = SCHEMA_DIR / "VERSION" @@ -143,23 +136,12 @@ def generate_schema() -> None: ) sys.exit(1) - schema_payload = json.loads(SCHEMA_JSON.read_text(encoding="utf-8")) - schema_payload, patch_warnings = apply_schema_patches(schema_payload) - for warning in patch_warnings: - print(f"Warning: {warning.message}", file=sys.stderr) - - patched_schema_path: Path | None = None - with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as handle: - json.dump(schema_payload, handle, indent=2) - handle.write("\n") - patched_schema_path = Path(handle.name) - cmd = [ sys.executable, "-m", "datamodel_code_generator", "--input", - str(patched_schema_path), + str(SCHEMA_JSON), "--input-file-type", "jsonschema", "--output", @@ -173,15 +155,10 @@ def generate_schema() -> None: "--snake-case-field", ] - try: - subprocess.check_call(cmd) # noqa: S603 - warnings = postprocess_generated_schema(SCHEMA_OUT) - for warning in warnings: - print(f"Warning: {warning}", file=sys.stderr) - finally: - if patched_schema_path is not None: - with contextlib.suppress(OSError): - patched_schema_path.unlink() + subprocess.check_call(cmd) # noqa: S603 + warnings = postprocess_generated_schema(SCHEMA_OUT) + for warning in warnings: + print(f"Warning: {warning}", file=sys.stderr) def postprocess_generated_schema(output_path: Path) -> list[str]: diff --git a/scripts/schema_patches.py b/scripts/schema_patches.py deleted file mode 100644 index c2f828a..0000000 --- a/scripts/schema_patches.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - - -@dataclass(frozen=True, slots=True) -class PatchWarning: - message: str - - -def apply_schema_patches(schema: dict[str, Any]) -> tuple[dict[str, Any], list[PatchWarning]]: - patched = schema - warnings: list[PatchWarning] = [] - - patched, warning = _make_defs_field_optional(patched, "NewSessionRequest", "mcpServers") - if warning is not None: - warnings.append(warning) - - patched, warning = _make_defs_field_optional(patched, "LoadSessionRequest", "mcpServers") - if warning is not None: - warnings.append(warning) - - return patched, warnings - - -def _make_defs_field_optional( - schema: dict[str, Any], - model_name: str, - field_name: str, -) -> tuple[dict[str, Any], PatchWarning | None]: - defs = schema.get("$defs") - if not isinstance(defs, dict): - return schema, PatchWarning("schema.$defs missing or invalid; cannot apply patches") - - model = defs.get(model_name) - if not isinstance(model, dict): - return schema, PatchWarning(f"schema.$defs.{model_name} missing or invalid; cannot patch {field_name}") - - required = model.get("required") - if required is None: - return schema, None - if not isinstance(required, list): - return schema, PatchWarning(f"schema.$defs.{model_name}.required invalid; cannot patch {field_name}") - - new_required = [item for item in required if item != field_name] - if new_required == required: - return schema, None - - model["required"] = new_required - return schema, None diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index dab450b..ac0d34f 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -45,7 +45,6 @@ __all__ = ["ClientSideConnection"] _CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader" -_MISSING = object() @final @@ -94,10 +93,7 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, - cwd: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any ) -> NewSessionResponse: return await request_model( self._conn, @@ -108,27 +104,12 @@ async def new_session( @param_model(LoadSessionRequest) async def load_session( - self, - cwd: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | str | None = None, - session_id: str | object = _MISSING, - **kwargs: Any, + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any ) -> LoadSessionResponse: - if session_id is _MISSING: - if isinstance(mcp_servers, str): - session_id = mcp_servers - mcp_servers = None - else: - raise TypeError("load_session() missing required argument: 'session_id'") return await request_model_from_dict( self._conn, AGENT_METHODS["session_load"], - LoadSessionRequest( - cwd=cwd, - mcp_servers=cast(list[HttpMcpServer | SseMcpServer | McpServerStdio] | None, mcp_servers), - session_id=cast(str, session_id), - field_meta=kwargs or None, - ), + LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None), LoadSessionResponse, ) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 3bf8eb4..457dfe7 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -154,19 +154,12 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, - cwd: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any ) -> NewSessionResponse: ... @param_model(LoadSessionRequest) async def load_session( - self, - cwd: str, - session_id: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any ) -> LoadSessionResponse | None: ... @param_model(ListSessionsRequest) diff --git a/src/acp/schema.py b/src/acp/schema.py index b0247fe..e449e4a 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -1410,12 +1410,12 @@ class NewSessionRequest(BaseModel): ] # List of MCP (Model Context Protocol) servers the agent should connect to. mcp_servers: Annotated[ - Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]], + List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]], Field( alias="mcpServers", description="List of MCP (Model Context Protocol) servers the agent should connect to.", ), - ] = None + ] class PermissionOption(BaseModel): @@ -2073,12 +2073,12 @@ class LoadSessionRequest(BaseModel): cwd: Annotated[str, Field(description="The working directory for this session.")] # List of MCP servers to connect to for this session. mcp_servers: Annotated[ - Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]], + List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]], Field( alias="mcpServers", description="List of MCP servers to connect to for this session.", ), - ] = None + ] # The ID of the session to load. session_id: Annotated[str, Field(alias="sessionId", description="The ID of the session to load.")] diff --git a/tests/conftest.py b/tests/conftest.py index 266fa71..6cce0b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -243,19 +243,12 @@ async def initialize( return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[]) async def new_session( - self, - cwd: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any ) -> NewSessionResponse: return NewSessionResponse(session_id="test-session-123") async def load_session( - self, - cwd: str, - session_id: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any ) -> LoadSessionResponse | None: return LoadSessionResponse() diff --git a/tests/real_user/test_issue_55_mcp_servers_optional.py b/tests/real_user/test_issue_55_mcp_servers_optional.py deleted file mode 100644 index 52ea330..0000000 --- a/tests/real_user/test_issue_55_mcp_servers_optional.py +++ /dev/null @@ -1,68 +0,0 @@ -import asyncio -from typing import Any - -import pytest - -from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse -from acp.core import AgentSideConnection, ClientSideConnection -from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer -from tests.conftest import TestAgent, TestClient - -# Regression from a real-world client run where `mcpServers` is omitted from session requests. - - -class Issue55Agent(TestAgent): - def __init__(self) -> None: - super().__init__() - self.seen_new_session: tuple[str, Any] | None = None - self.seen_load_session: tuple[str, str, Any] | None = None - - async def new_session( - self, - cwd: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, - ) -> NewSessionResponse: - self.seen_new_session = (cwd, mcp_servers) - return await super().new_session(cwd=cwd, mcp_servers=mcp_servers, **kwargs) - - async def load_session( - self, - cwd: str, - session_id: str, - mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, - **kwargs: Any, - ) -> LoadSessionResponse | None: - self.seen_load_session = (cwd, session_id, mcp_servers) - return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=mcp_servers, **kwargs) - - -@pytest.mark.asyncio -async def test_session_requests_allow_missing_mcp_servers(server) -> None: - client = TestClient() - captured_agent: list[Issue55Agent] = [] - - agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type] - _agent_side = AgentSideConnection( - lambda _conn: captured_agent.append(Issue55Agent()) or captured_agent[-1], - server._server_writer, - server._server_reader, - listening=True, - ) - - init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0) - assert isinstance(init, InitializeResponse) - - new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0) - assert isinstance(new_session, NewSessionResponse) - - load_session = await asyncio.wait_for( - agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id), - timeout=1.0, - ) - assert isinstance(load_session, LoadSessionResponse) - - assert captured_agent, "Agent was not constructed" - [agent] = captured_agent - assert agent.seen_new_session == ("/workspace", None) - assert agent.seen_load_session == ("/workspace", new_session.session_id, None) From 283fdc0a450365a278f20ed8d504fb771e04a18b Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sat, 7 Feb 2026 16:25:53 +0000 Subject: [PATCH 5/5] fix: default mcpServers to empty list --- scripts/gen_signature.py | 30 +++++++-- src/acp/client/connection.py | 16 +++-- src/acp/interfaces.py | 8 ++- tests/real_user/test_mcp_servers_optional.py | 68 ++++++++++++++++++++ 4 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 tests/real_user/test_mcp_servers_optional.py diff --git a/scripts/gen_signature.py b/scripts/gen_signature.py index b3a7add..b435e2c 100644 --- a/scripts/gen_signature.py +++ b/scripts/gen_signature.py @@ -9,6 +9,11 @@ from acp import schema +SIGNATURE_OPTIONAL_FIELDS: set[tuple[str, str]] = { + ("LoadSessionRequest", "mcp_servers"), + ("NewSessionRequest", "mcp_servers"), +} + class NodeTransformer(ast.NodeTransformer): def __init__(self) -> None: @@ -16,6 +21,7 @@ def __init__(self) -> None: self._schema_import_node: ast.ImportFrom | None = None self._should_rewrite = False self._literals = {name: value for name, value in schema.__dict__.items() if t.get_origin(value) is t.Literal} + self._current_model_name: str | None = None def _add_typing_import(self, name: str) -> None: if not self._type_import_node: @@ -71,9 +77,13 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST: self._should_rewrite = True model_name = t.cast(ast.Name, decorator.args[0]).id model = t.cast(type[schema.BaseModel], getattr(schema, model_name)) - param_defaults = [ - self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta" - ] + self._current_model_name = model_name + try: + param_defaults = [ + self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta" + ] + finally: + self._current_model_name = None param_defaults.sort(key=lambda x: x[1] is not None) node.args.args[1:] = [param for param, _ in param_defaults] node.args.defaults = [default for _, default in param_defaults if default is not None] @@ -84,12 +94,18 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST: def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr | None]: arg = ast.arg(arg=name) ann = field.annotation - if field.default is PydanticUndefined: - default = None - elif isinstance(field.default, dict | BaseModel): + override_optional = (self._current_model_name, name) in SIGNATURE_OPTIONAL_FIELDS + if override_optional: + if ann is not None: + ann = ann | None default = ast.Constant(None) else: - default = ast.Constant(value=field.default) + if field.default is PydanticUndefined: + default = None + elif isinstance(field.default, dict | BaseModel): + default = ast.Constant(None) + else: + default = ast.Constant(value=field.default) if ann is not None: arg.annotation = self._format_annotation(ann) return arg, default diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index ac0d34f..9831d7e 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -93,23 +93,31 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any ) -> NewSessionResponse: + resolved_mcp_servers = mcp_servers or [] return await request_model( self._conn, AGENT_METHODS["session_new"], - NewSessionRequest(cwd=cwd, mcp_servers=mcp_servers, field_meta=kwargs or None), + NewSessionRequest(cwd=cwd, mcp_servers=resolved_mcp_servers, field_meta=kwargs or None), NewSessionResponse, ) @param_model(LoadSessionRequest) async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> LoadSessionResponse: + resolved_mcp_servers = mcp_servers or [] return await request_model_from_dict( self._conn, AGENT_METHODS["session_load"], - LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None), + LoadSessionRequest( + cwd=cwd, mcp_servers=resolved_mcp_servers, session_id=session_id, field_meta=kwargs or None + ), LoadSessionResponse, ) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 457dfe7..55c00f3 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -154,12 +154,16 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any ) -> NewSessionResponse: ... @param_model(LoadSessionRequest) async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> LoadSessionResponse | None: ... @param_model(ListSessionsRequest) diff --git a/tests/real_user/test_mcp_servers_optional.py b/tests/real_user/test_mcp_servers_optional.py new file mode 100644 index 0000000..96aae75 --- /dev/null +++ b/tests/real_user/test_mcp_servers_optional.py @@ -0,0 +1,68 @@ +import asyncio +from typing import Any + +import pytest + +from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse +from acp.core import AgentSideConnection, ClientSideConnection +from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer +from tests.conftest import TestAgent, TestClient + + +class McpOptionalAgent(TestAgent): + def __init__(self) -> None: + super().__init__() + self.seen_new_session: tuple[str, Any] | None = None + self.seen_load_session: tuple[str, str, Any] | None = None + + async def new_session( + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, + ) -> NewSessionResponse: + resolved_mcp_servers = mcp_servers or [] + self.seen_new_session = (cwd, resolved_mcp_servers) + return await super().new_session(cwd=cwd, mcp_servers=resolved_mcp_servers, **kwargs) + + async def load_session( + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, + ) -> LoadSessionResponse | None: + resolved_mcp_servers = mcp_servers or [] + self.seen_load_session = (cwd, session_id, resolved_mcp_servers) + return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=resolved_mcp_servers, **kwargs) + + +@pytest.mark.asyncio +async def test_session_requests_default_empty_mcp_servers(server) -> None: + client = TestClient() + captured_agent: list[McpOptionalAgent] = [] + + agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type] + _agent_side = AgentSideConnection( + lambda _conn: captured_agent.append(McpOptionalAgent()) or captured_agent[-1], + server._server_writer, + server._server_reader, + listening=True, + ) + + init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0) + assert isinstance(init, InitializeResponse) + + new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0) + assert isinstance(new_session, NewSessionResponse) + + load_session = await asyncio.wait_for( + agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id), + timeout=1.0, + ) + assert isinstance(load_session, LoadSessionResponse) + + assert captured_agent, "Agent was not constructed" + [agent] = captured_agent + assert agent.seen_new_session == ("/workspace", []) + assert agent.seen_load_session == ("/workspace", new_session.session_id, [])