Skip to content

Commit e7b2f44

Browse files
committed
feat: add reconnect_mcp_server and toggle_mcp_server client methods
1 parent a58d3ab commit e7b2f44

4 files changed

Lines changed: 265 additions & 0 deletions

File tree

src/claude_agent_sdk/_internal/query.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,34 @@ async def rewind_files(self, user_message_id: str) -> None:
567567
}
568568
)
569569

570+
async def reconnect_mcp_server(self, server_name: str) -> None:
571+
"""Reconnect a disconnected or failed MCP server.
572+
573+
Args:
574+
server_name: The name of the MCP server to reconnect
575+
"""
576+
await self._send_control_request(
577+
{
578+
"subtype": "mcp_reconnect",
579+
"serverName": server_name,
580+
}
581+
)
582+
583+
async def toggle_mcp_server(self, server_name: str, enabled: bool) -> None:
584+
"""Enable or disable an MCP server.
585+
586+
Args:
587+
server_name: The name of the MCP server to toggle
588+
enabled: Whether the server should be enabled
589+
"""
590+
await self._send_control_request(
591+
{
592+
"subtype": "mcp_toggle",
593+
"serverName": server_name,
594+
"enabled": enabled,
595+
}
596+
)
597+
570598
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
571599
"""Stream input messages to transport.
572600

src/claude_agent_sdk/client.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,54 @@ async def rewind_files(self, user_message_id: str) -> None:
304304
raise CLIConnectionError("Not connected. Call connect() first.")
305305
await self._query.rewind_files(user_message_id)
306306

307+
async def reconnect_mcp_server(self, server_name: str) -> None:
308+
"""Reconnect a disconnected or failed MCP server (only works with streaming mode).
309+
310+
Use this to retry connecting to an MCP server that failed to connect
311+
or was disconnected. Raises an exception if the reconnection fails.
312+
313+
Args:
314+
server_name: The name of the MCP server to reconnect
315+
316+
Example:
317+
```python
318+
async with ClaudeSDKClient(options) as client:
319+
status = await client.get_mcp_status()
320+
for server in status.get("mcpServers", []):
321+
if server["status"] == "failed":
322+
await client.reconnect_mcp_server(server["name"])
323+
```
324+
"""
325+
if not self._query:
326+
raise CLIConnectionError("Not connected. Call connect() first.")
327+
await self._query.reconnect_mcp_server(server_name)
328+
329+
async def toggle_mcp_server(self, server_name: str, enabled: bool) -> None:
330+
"""Enable or disable an MCP server (only works with streaming mode).
331+
332+
Disabling a server disconnects it and removes its tools from the
333+
available tool set. Enabling a server reconnects it and makes its
334+
tools available again. Raises an exception on failure.
335+
336+
Args:
337+
server_name: The name of the MCP server to toggle
338+
enabled: True to enable the server, False to disable it
339+
340+
Example:
341+
```python
342+
async with ClaudeSDKClient(options) as client:
343+
# Temporarily disable a server
344+
await client.toggle_mcp_server("my-server", enabled=False)
345+
await client.query("Do something without my-server tools")
346+
347+
# Re-enable it later
348+
await client.toggle_mcp_server("my-server", enabled=True)
349+
```
350+
"""
351+
if not self._query:
352+
raise CLIConnectionError("Not connected. Call connect() first.")
353+
await self._query.toggle_mcp_server(server_name, enabled)
354+
307355
async def get_mcp_status(self) -> dict[str, Any]:
308356
"""Get current MCP server connection status (only works with streaming mode).
309357

src/claude_agent_sdk/types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,23 @@ class SDKControlRewindFilesRequest(TypedDict):
828828
user_message_id: str
829829

830830

831+
class SDKControlMcpReconnectRequest(TypedDict):
832+
"""Reconnects a disconnected or failed MCP server."""
833+
834+
subtype: Literal["mcp_reconnect"]
835+
# Note: wire protocol uses camelCase for this field
836+
serverName: str
837+
838+
839+
class SDKControlMcpToggleRequest(TypedDict):
840+
"""Enables or disables an MCP server."""
841+
842+
subtype: Literal["mcp_toggle"]
843+
# Note: wire protocol uses camelCase for this field
844+
serverName: str
845+
enabled: bool
846+
847+
831848
class SDKControlRequest(TypedDict):
832849
type: Literal["control_request"]
833850
request_id: str
@@ -839,6 +856,8 @@ class SDKControlRequest(TypedDict):
839856
| SDKHookCallbackRequest
840857
| SDKControlMcpMessageRequest
841858
| SDKControlRewindFilesRequest
859+
| SDKControlMcpReconnectRequest
860+
| SDKControlMcpToggleRequest
842861
)
843862

844863

tests/test_streaming_client.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,54 @@ async def control_protocol_generator():
107107
return mock_transport
108108

109109

110+
def _create_mock_transport_with_control_responses():
111+
"""Create a mock transport that responds with success to all control requests.
112+
113+
Useful for testing client methods that send control requests (e.g.
114+
reconnect_mcp_server, toggle_mcp_server) without needing to special-case
115+
each subtype in the mock.
116+
"""
117+
mock_transport = AsyncMock()
118+
mock_transport.connect = AsyncMock()
119+
mock_transport.close = AsyncMock()
120+
mock_transport.end_input = AsyncMock()
121+
mock_transport.is_ready = Mock(return_value=True)
122+
123+
written_messages: list[str] = []
124+
125+
async def mock_write(data):
126+
written_messages.append(data)
127+
128+
mock_transport.write = AsyncMock(side_effect=mock_write)
129+
130+
async def control_protocol_generator():
131+
# Poll for control requests and respond with success to each one.
132+
last_check = 0
133+
timeout_counter = 0
134+
while timeout_counter < 200: # Avoid infinite loop
135+
await asyncio.sleep(0.01)
136+
timeout_counter += 1
137+
138+
for msg_str in written_messages[last_check:]:
139+
try:
140+
msg = json.loads(msg_str.strip())
141+
if msg.get("type") == "control_request":
142+
yield {
143+
"type": "control_response",
144+
"response": {
145+
"request_id": msg.get("request_id"),
146+
"subtype": "success",
147+
"response": {},
148+
},
149+
}
150+
except (json.JSONDecodeError, KeyError, AttributeError):
151+
pass
152+
last_check = len(written_messages)
153+
154+
mock_transport.read_messages = control_protocol_generator
155+
return mock_transport
156+
157+
110158
class TestClaudeSDKClientStreaming:
111159
"""Test ClaudeSDKClient streaming functionality."""
112160

@@ -467,6 +515,128 @@ async def _test():
467515

468516
anyio.run(_test)
469517

518+
def test_reconnect_mcp_server(self):
519+
"""Test reconnect_mcp_server sends correct control request."""
520+
521+
async def _test():
522+
with patch(
523+
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
524+
) as mock_transport_class:
525+
mock_transport = _create_mock_transport_with_control_responses()
526+
mock_transport_class.return_value = mock_transport
527+
528+
async with ClaudeSDKClient() as client:
529+
await client.reconnect_mcp_server("my-server")
530+
# Check that a control request was sent via write
531+
write_calls = mock_transport.write.call_args_list
532+
request_found = False
533+
for call in write_calls:
534+
data = call[0][0]
535+
try:
536+
msg = json.loads(data.strip())
537+
req = msg.get("request", {})
538+
if (
539+
msg.get("type") == "control_request"
540+
and req.get("subtype") == "mcp_reconnect"
541+
):
542+
# Verify wire format uses camelCase serverName
543+
assert req.get("serverName") == "my-server"
544+
request_found = True
545+
break
546+
except (json.JSONDecodeError, KeyError, AttributeError):
547+
pass
548+
assert request_found, "mcp_reconnect control request not found"
549+
550+
anyio.run(_test)
551+
552+
def test_reconnect_mcp_server_not_connected(self):
553+
"""Test reconnect_mcp_server when not connected raises error."""
554+
555+
async def _test():
556+
client = ClaudeSDKClient()
557+
with pytest.raises(CLIConnectionError, match="Not connected"):
558+
await client.reconnect_mcp_server("my-server")
559+
560+
anyio.run(_test)
561+
562+
def test_toggle_mcp_server(self):
563+
"""Test toggle_mcp_server sends correct control request."""
564+
565+
async def _test():
566+
with patch(
567+
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
568+
) as mock_transport_class:
569+
mock_transport = _create_mock_transport_with_control_responses()
570+
mock_transport_class.return_value = mock_transport
571+
572+
async with ClaudeSDKClient() as client:
573+
await client.toggle_mcp_server("my-server", False)
574+
# Check that a control request was sent via write
575+
write_calls = mock_transport.write.call_args_list
576+
request_found = False
577+
for call in write_calls:
578+
data = call[0][0]
579+
try:
580+
msg = json.loads(data.strip())
581+
req = msg.get("request", {})
582+
if (
583+
msg.get("type") == "control_request"
584+
and req.get("subtype") == "mcp_toggle"
585+
):
586+
# Verify wire format uses camelCase serverName
587+
assert req.get("serverName") == "my-server"
588+
assert req.get("enabled") is False
589+
request_found = True
590+
break
591+
except (json.JSONDecodeError, KeyError, AttributeError):
592+
pass
593+
assert request_found, "mcp_toggle control request not found"
594+
595+
anyio.run(_test)
596+
597+
def test_toggle_mcp_server_enabled_true(self):
598+
"""Test toggle_mcp_server with enabled=True."""
599+
600+
async def _test():
601+
with patch(
602+
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
603+
) as mock_transport_class:
604+
mock_transport = _create_mock_transport_with_control_responses()
605+
mock_transport_class.return_value = mock_transport
606+
607+
async with ClaudeSDKClient() as client:
608+
await client.toggle_mcp_server("other-server", True)
609+
write_calls = mock_transport.write.call_args_list
610+
request_found = False
611+
for call in write_calls:
612+
data = call[0][0]
613+
try:
614+
msg = json.loads(data.strip())
615+
req = msg.get("request", {})
616+
if (
617+
msg.get("type") == "control_request"
618+
and req.get("subtype") == "mcp_toggle"
619+
):
620+
assert req.get("serverName") == "other-server"
621+
assert req.get("enabled") is True
622+
request_found = True
623+
break
624+
except (json.JSONDecodeError, KeyError, AttributeError):
625+
pass
626+
assert request_found, "mcp_toggle control request not found"
627+
628+
anyio.run(_test)
629+
630+
def test_toggle_mcp_server_not_connected(self):
631+
"""Test toggle_mcp_server when not connected raises error."""
632+
633+
async def _test():
634+
client = ClaudeSDKClient()
635+
with pytest.raises(CLIConnectionError, match="Not connected"):
636+
await client.toggle_mcp_server("my-server", True)
637+
638+
anyio.run(_test)
639+
470640
def test_client_with_options(self):
471641
"""Test client initialization with options."""
472642

0 commit comments

Comments
 (0)