Skip to content

Commit 73499a8

Browse files
committed
fix: allow omitting mcpServers in session requests
1 parent b4f253c commit 73499a8

File tree

7 files changed

+191
-16
lines changed

7 files changed

+191
-16
lines changed

scripts/gen_schema.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,23 @@
22
from __future__ import annotations
33

44
import ast
5+
import contextlib
56
import json
67
import re
78
import subprocess
89
import sys
10+
import tempfile
911
import textwrap
1012
from collections.abc import Callable
1113
from dataclasses import dataclass
1214
from pathlib import Path
1315

1416
ROOT = Path(__file__).resolve().parents[1]
17+
if str(ROOT) not in sys.path:
18+
sys.path.append(str(ROOT))
19+
20+
from scripts.schema_patches import apply_schema_patches # noqa: E402
21+
1522
SCHEMA_DIR = ROOT / "schema"
1623
SCHEMA_JSON = SCHEMA_DIR / "schema.json"
1724
VERSION_FILE = SCHEMA_DIR / "VERSION"
@@ -136,12 +143,23 @@ def generate_schema() -> None:
136143
)
137144
sys.exit(1)
138145

146+
schema_payload = json.loads(SCHEMA_JSON.read_text(encoding="utf-8"))
147+
schema_payload, patch_warnings = apply_schema_patches(schema_payload)
148+
for warning in patch_warnings:
149+
print(f"Warning: {warning.message}", file=sys.stderr)
150+
151+
patched_schema_path: Path | None = None
152+
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as handle:
153+
json.dump(schema_payload, handle, indent=2)
154+
handle.write("\n")
155+
patched_schema_path = Path(handle.name)
156+
139157
cmd = [
140158
sys.executable,
141159
"-m",
142160
"datamodel_code_generator",
143161
"--input",
144-
str(SCHEMA_JSON),
162+
str(patched_schema_path),
145163
"--input-file-type",
146164
"jsonschema",
147165
"--output",
@@ -155,10 +173,15 @@ def generate_schema() -> None:
155173
"--snake-case-field",
156174
]
157175

158-
subprocess.check_call(cmd) # noqa: S603
159-
warnings = postprocess_generated_schema(SCHEMA_OUT)
160-
for warning in warnings:
161-
print(f"Warning: {warning}", file=sys.stderr)
176+
try:
177+
subprocess.check_call(cmd) # noqa: S603
178+
warnings = postprocess_generated_schema(SCHEMA_OUT)
179+
for warning in warnings:
180+
print(f"Warning: {warning}", file=sys.stderr)
181+
finally:
182+
if patched_schema_path is not None:
183+
with contextlib.suppress(OSError):
184+
patched_schema_path.unlink()
162185

163186

164187
def postprocess_generated_schema(output_path: Path) -> list[str]:

scripts/schema_patches.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any
5+
6+
7+
@dataclass(frozen=True, slots=True)
8+
class PatchWarning:
9+
message: str
10+
11+
12+
def apply_schema_patches(schema: dict[str, Any]) -> tuple[dict[str, Any], list[PatchWarning]]:
13+
patched = schema
14+
warnings: list[PatchWarning] = []
15+
16+
patched, warning = _make_defs_field_optional(patched, "NewSessionRequest", "mcpServers")
17+
if warning is not None:
18+
warnings.append(warning)
19+
20+
patched, warning = _make_defs_field_optional(patched, "LoadSessionRequest", "mcpServers")
21+
if warning is not None:
22+
warnings.append(warning)
23+
24+
return patched, warnings
25+
26+
27+
def _make_defs_field_optional(
28+
schema: dict[str, Any],
29+
model_name: str,
30+
field_name: str,
31+
) -> tuple[dict[str, Any], PatchWarning | None]:
32+
defs = schema.get("$defs")
33+
if not isinstance(defs, dict):
34+
return schema, PatchWarning("schema.$defs missing or invalid; cannot apply patches")
35+
36+
model = defs.get(model_name)
37+
if not isinstance(model, dict):
38+
return schema, PatchWarning(f"schema.$defs.{model_name} missing or invalid; cannot patch {field_name}")
39+
40+
required = model.get("required")
41+
if required is None:
42+
return schema, None
43+
if not isinstance(required, list):
44+
return schema, PatchWarning(f"schema.$defs.{model_name}.required invalid; cannot patch {field_name}")
45+
46+
new_required = [item for item in required if item != field_name]
47+
if new_required == required:
48+
return schema, None
49+
50+
model["required"] = new_required
51+
return schema, None

src/acp/client/connection.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
__all__ = ["ClientSideConnection"]
4747
_CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader"
48+
_MISSING = object()
4849

4950

5051
@final
@@ -93,7 +94,10 @@ async def initialize(
9394

9495
@param_model(NewSessionRequest)
9596
async def new_session(
96-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
97+
self,
98+
cwd: str,
99+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
100+
**kwargs: Any,
97101
) -> NewSessionResponse:
98102
return await request_model(
99103
self._conn,
@@ -104,12 +108,27 @@ async def new_session(
104108

105109
@param_model(LoadSessionRequest)
106110
async def load_session(
107-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
111+
self,
112+
cwd: str,
113+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | str | None = None,
114+
session_id: str | object = _MISSING,
115+
**kwargs: Any,
108116
) -> LoadSessionResponse:
117+
if session_id is _MISSING:
118+
if isinstance(mcp_servers, str):
119+
session_id = mcp_servers
120+
mcp_servers = None
121+
else:
122+
raise TypeError("load_session() missing required argument: 'session_id'")
109123
return await request_model_from_dict(
110124
self._conn,
111125
AGENT_METHODS["session_load"],
112-
LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None),
126+
LoadSessionRequest(
127+
cwd=cwd,
128+
mcp_servers=cast(list[HttpMcpServer | SseMcpServer | McpServerStdio] | None, mcp_servers),
129+
session_id=cast(str, session_id),
130+
field_meta=kwargs or None,
131+
),
113132
LoadSessionResponse,
114133
)
115134

src/acp/interfaces.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,19 @@ async def initialize(
154154

155155
@param_model(NewSessionRequest)
156156
async def new_session(
157-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
157+
self,
158+
cwd: str,
159+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
160+
**kwargs: Any,
158161
) -> NewSessionResponse: ...
159162

160163
@param_model(LoadSessionRequest)
161164
async def load_session(
162-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
165+
self,
166+
cwd: str,
167+
session_id: str,
168+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
169+
**kwargs: Any,
163170
) -> LoadSessionResponse | None: ...
164171

165172
@param_model(ListSessionsRequest)

src/acp/schema.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,12 +1410,12 @@ class NewSessionRequest(BaseModel):
14101410
]
14111411
# List of MCP (Model Context Protocol) servers the agent should connect to.
14121412
mcp_servers: Annotated[
1413-
List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]],
1413+
Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]],
14141414
Field(
14151415
alias="mcpServers",
14161416
description="List of MCP (Model Context Protocol) servers the agent should connect to.",
14171417
),
1418-
]
1418+
] = None
14191419

14201420

14211421
class PermissionOption(BaseModel):
@@ -2073,12 +2073,12 @@ class LoadSessionRequest(BaseModel):
20732073
cwd: Annotated[str, Field(description="The working directory for this session.")]
20742074
# List of MCP servers to connect to for this session.
20752075
mcp_servers: Annotated[
2076-
List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]],
2076+
Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]],
20772077
Field(
20782078
alias="mcpServers",
20792079
description="List of MCP servers to connect to for this session.",
20802080
),
2081-
]
2081+
] = None
20822082
# The ID of the session to load.
20832083
session_id: Annotated[str, Field(alias="sessionId", description="The ID of the session to load.")]
20842084

tests/conftest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,19 @@ async def initialize(
243243
return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[])
244244

245245
async def new_session(
246-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
246+
self,
247+
cwd: str,
248+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
249+
**kwargs: Any,
247250
) -> NewSessionResponse:
248251
return NewSessionResponse(session_id="test-session-123")
249252

250253
async def load_session(
251-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
254+
self,
255+
cwd: str,
256+
session_id: str,
257+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
258+
**kwargs: Any,
252259
) -> LoadSessionResponse | None:
253260
return LoadSessionResponse()
254261

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import asyncio
2+
from typing import Any
3+
4+
import pytest
5+
6+
from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse
7+
from acp.core import AgentSideConnection, ClientSideConnection
8+
from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer
9+
from tests.conftest import TestAgent, TestClient
10+
11+
# Regression from a real-world client run where `mcpServers` is omitted from session requests.
12+
13+
14+
class Issue55Agent(TestAgent):
15+
def __init__(self) -> None:
16+
super().__init__()
17+
self.seen_new_session: tuple[str, Any] | None = None
18+
self.seen_load_session: tuple[str, str, Any] | None = None
19+
20+
async def new_session(
21+
self,
22+
cwd: str,
23+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
24+
**kwargs: Any,
25+
) -> NewSessionResponse:
26+
self.seen_new_session = (cwd, mcp_servers)
27+
return await super().new_session(cwd=cwd, mcp_servers=mcp_servers, **kwargs)
28+
29+
async def load_session(
30+
self,
31+
cwd: str,
32+
session_id: str,
33+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
34+
**kwargs: Any,
35+
) -> LoadSessionResponse | None:
36+
self.seen_load_session = (cwd, session_id, mcp_servers)
37+
return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=mcp_servers, **kwargs)
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_session_requests_allow_missing_mcp_servers(server) -> None:
42+
client = TestClient()
43+
captured_agent: list[Issue55Agent] = []
44+
45+
agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type]
46+
_agent_side = AgentSideConnection(
47+
lambda _conn: captured_agent.append(Issue55Agent()) or captured_agent[-1],
48+
server._server_writer,
49+
server._server_reader,
50+
listening=True,
51+
)
52+
53+
init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0)
54+
assert isinstance(init, InitializeResponse)
55+
56+
new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0)
57+
assert isinstance(new_session, NewSessionResponse)
58+
59+
load_session = await asyncio.wait_for(
60+
agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id),
61+
timeout=1.0,
62+
)
63+
assert isinstance(load_session, LoadSessionResponse)
64+
65+
assert captured_agent, "Agent was not constructed"
66+
[agent] = captured_agent
67+
assert agent.seen_new_session == ("/workspace", None)
68+
assert agent.seen_load_session == ("/workspace", new_session.session_id, None)

0 commit comments

Comments
 (0)