diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 379d6d893..941eed9f0 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -31,6 +31,13 @@ HookMatcher, McpSdkServerConfig, McpServerConfig, + McpServerConnectionStatus, + McpServerInfo, + McpServerStatus, + McpServerStatusConfig, + McpStatusResponse, + McpToolAnnotations, + McpToolInfo, Message, NotificationHookInput, NotificationHookSpecificOutput, @@ -330,6 +337,13 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "PermissionMode", "McpServerConfig", "McpSdkServerConfig", + "McpServerStatus", + "McpServerStatusConfig", + "McpServerConnectionStatus", + "McpServerInfo", + "McpStatusResponse", + "McpToolAnnotations", + "McpToolInfo", "UserMessage", "AssistantMessage", "SystemMessage", diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 8f2784286..bbecb3e3a 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -567,6 +567,47 @@ async def rewind_files(self, user_message_id: str) -> None: } ) + async def reconnect_mcp_server(self, server_name: str) -> None: + """Reconnect a disconnected or failed MCP server. + + Args: + server_name: The name of the MCP server to reconnect + """ + await self._send_control_request( + { + "subtype": "mcp_reconnect", + "serverName": server_name, + } + ) + + async def toggle_mcp_server(self, server_name: str, enabled: bool) -> None: + """Enable or disable an MCP server. + + Args: + server_name: The name of the MCP server to toggle + enabled: Whether the server should be enabled + """ + await self._send_control_request( + { + "subtype": "mcp_toggle", + "serverName": server_name, + "enabled": enabled, + } + ) + + async def stop_task(self, task_id: str) -> None: + """Stop a running task. + + Args: + task_id: The task ID from task_notification events + """ + await self._send_control_request( + { + "subtype": "stop_task", + "task_id": task_id, + } + ) + async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None: """Stream input messages to transport. diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index 490bc4a29..fc7b6754d 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -8,7 +8,14 @@ from . import Transport from ._errors import CLIConnectionError -from .types import ClaudeAgentOptions, HookEvent, HookMatcher, Message, ResultMessage +from .types import ( + ClaudeAgentOptions, + HookEvent, + HookMatcher, + McpStatusResponse, + Message, + ResultMessage, +) class ClaudeSDKClient: @@ -304,30 +311,108 @@ async def rewind_files(self, user_message_id: str) -> None: raise CLIConnectionError("Not connected. Call connect() first.") await self._query.rewind_files(user_message_id) - async def get_mcp_status(self) -> dict[str, Any]: + async def reconnect_mcp_server(self, server_name: str) -> None: + """Reconnect a disconnected or failed MCP server (only works with streaming mode). + + Use this to retry connecting to an MCP server that failed to connect + or was disconnected. Raises an exception if the reconnection fails. + + Args: + server_name: The name of the MCP server to reconnect + + Example: + ```python + async with ClaudeSDKClient(options) as client: + status = await client.get_mcp_status() + for server in status.get("mcpServers", []): + if server["status"] == "failed": + await client.reconnect_mcp_server(server["name"]) + ``` + """ + if not self._query: + raise CLIConnectionError("Not connected. Call connect() first.") + await self._query.reconnect_mcp_server(server_name) + + async def toggle_mcp_server(self, server_name: str, enabled: bool) -> None: + """Enable or disable an MCP server (only works with streaming mode). + + Disabling a server disconnects it and removes its tools from the + available tool set. Enabling a server reconnects it and makes its + tools available again. Raises an exception on failure. + + Args: + server_name: The name of the MCP server to toggle + enabled: True to enable the server, False to disable it + + Example: + ```python + async with ClaudeSDKClient(options) as client: + # Temporarily disable a server + await client.toggle_mcp_server("my-server", enabled=False) + await client.query("Do something without my-server tools") + + # Re-enable it later + await client.toggle_mcp_server("my-server", enabled=True) + ``` + """ + if not self._query: + raise CLIConnectionError("Not connected. Call connect() first.") + await self._query.toggle_mcp_server(server_name, enabled) + + async def stop_task(self, task_id: str) -> None: + """Stop a running task (only works with streaming mode). + + After this resolves, a `task_notification` system message with + status `'stopped'` will be emitted by the CLI in the message stream. + + Args: + task_id: The task ID from `task_notification` events. + + Example: + ```python + async with ClaudeSDKClient() as client: + await client.query("Start a long-running task") + + # Listen for task_notification to get task_id, then: + await client.stop_task("task-abc123") + # A task_notification with status 'stopped' will follow + ``` + """ + if not self._query: + raise CLIConnectionError("Not connected. Call connect() first.") + await self._query.stop_task(task_id) + + async def get_mcp_status(self) -> McpStatusResponse: """Get current MCP server connection status (only works with streaming mode). Queries the Claude Code CLI for the live connection status of all configured MCP servers. Returns: - Dictionary with MCP server status information. Contains a - 'mcpServers' key with a list of server status objects, each having: + McpStatusResponse dictionary with an 'mcpServers' key containing + a list of McpServerStatus entries. Each entry includes: - 'name': Server name (str) - 'status': Connection status ('connected', 'pending', 'failed', 'needs-auth', 'disabled') + - 'serverInfo': MCP server name/version (when connected) + - 'error': Error message (when status is 'failed') + - 'config': Server configuration (stdio/sse/http/sdk/claudeai-proxy) + - 'scope': Configuration scope (e.g., project, user, local) + - 'tools': List of tools provided by the server (when connected) Example: ```python async with ClaudeSDKClient(options) as client: status = await client.get_mcp_status() - for server in status.get("mcpServers", []): + for server in status["mcpServers"]: print(f"{server['name']}: {server['status']}") + if server["status"] == "failed": + print(f" Error: {server.get('error')}") ``` """ if not self._query: raise CLIConnectionError("Not connected. Call connect() first.") - result: dict[str, Any] = await self._query.get_mcp_status() + result: McpStatusResponse = await self._query.get_mcp_status() return result async def get_server_info(self) -> dict[str, Any] | None: diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 3ea89d5a4..c9650cccd 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -504,6 +504,116 @@ class McpSdkServerConfig(TypedDict): ) +# MCP Server Status types (returned by get_mcp_status) +# These mirror the TypeScript SDK's McpServerStatus type and use wire-format +# field names (camelCase where applicable) since they come directly from CLI +# JSON output. + + +class McpSdkServerConfigStatus(TypedDict): + """SDK MCP server config as returned in status responses. + + Unlike McpSdkServerConfig (which includes the in-process `instance`), + this output-only type only has serializable fields. + """ + + type: Literal["sdk"] + name: str + + +class McpClaudeAIProxyServerConfig(TypedDict): + """Claude.ai proxy MCP server config. + + Output-only type that appears in status responses for servers proxied + through Claude.ai. + """ + + type: Literal["claudeai-proxy"] + url: str + id: str + + +# Broader config type for status responses (includes claudeai-proxy which is +# output-only) +McpServerStatusConfig = ( + McpStdioServerConfig + | McpSSEServerConfig + | McpHttpServerConfig + | McpSdkServerConfigStatus + | McpClaudeAIProxyServerConfig +) + + +class McpToolAnnotations(TypedDict, total=False): + """Tool annotations as returned in MCP server status. + + Wire format uses camelCase field names (from CLI JSON output). + """ + + readOnly: bool + destructive: bool + openWorld: bool + + +class McpToolInfo(TypedDict): + """Information about a tool provided by an MCP server.""" + + name: str + description: NotRequired[str] + annotations: NotRequired[McpToolAnnotations] + + +class McpServerInfo(TypedDict): + """Server info from MCP initialize handshake (available when connected).""" + + name: str + version: str + + +# Connection status values for an MCP server +McpServerConnectionStatus = Literal[ + "connected", "failed", "needs-auth", "pending", "disabled" +] + + +class McpServerStatus(TypedDict): + """Status information for an MCP server connection. + + Returned by `ClaudeSDKClient.get_mcp_status()` in the `mcpServers` list. + """ + + name: str + """Server name as configured.""" + + status: McpServerConnectionStatus + """Current connection status.""" + + serverInfo: NotRequired[McpServerInfo] + """Server information from MCP handshake (available when connected).""" + + error: NotRequired[str] + """Error message (available when status is 'failed').""" + + config: NotRequired[McpServerStatusConfig] + """Server configuration (includes URL for HTTP/SSE servers).""" + + scope: NotRequired[str] + """Configuration scope (e.g., project, user, local, claudeai, managed).""" + + tools: NotRequired[list[McpToolInfo]] + """Tools provided by this server (available when connected).""" + + +class McpStatusResponse(TypedDict): + """Response from `ClaudeSDKClient.get_mcp_status()`. + + Wraps the list of server statuses under the `mcpServers` key, matching + the wire-format response shape. + """ + + mcpServers: list[McpServerStatus] + + class SdkPluginConfig(TypedDict): """SDK plugin configuration. @@ -828,6 +938,28 @@ class SDKControlRewindFilesRequest(TypedDict): user_message_id: str +class SDKControlMcpReconnectRequest(TypedDict): + """Reconnects a disconnected or failed MCP server.""" + + subtype: Literal["mcp_reconnect"] + # Note: wire protocol uses camelCase for this field + serverName: str + + +class SDKControlMcpToggleRequest(TypedDict): + """Enables or disables an MCP server.""" + + subtype: Literal["mcp_toggle"] + # Note: wire protocol uses camelCase for this field + serverName: str + enabled: bool + + +class SDKControlStopTaskRequest(TypedDict): + subtype: Literal["stop_task"] + task_id: str + + class SDKControlRequest(TypedDict): type: Literal["control_request"] request_id: str @@ -839,6 +971,9 @@ class SDKControlRequest(TypedDict): | SDKHookCallbackRequest | SDKControlMcpMessageRequest | SDKControlRewindFilesRequest + | SDKControlMcpReconnectRequest + | SDKControlMcpToggleRequest + | SDKControlStopTaskRequest ) diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 292944197..7821be4ed 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -107,6 +107,54 @@ async def control_protocol_generator(): return mock_transport +def _create_mock_transport_with_control_responses(): + """Create a mock transport that responds with success to all control requests. + + Useful for testing client methods that send control requests (e.g. + reconnect_mcp_server, toggle_mcp_server) without needing to special-case + each subtype in the mock. + """ + mock_transport = AsyncMock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + + written_messages: list[str] = [] + + async def mock_write(data): + written_messages.append(data) + + mock_transport.write = AsyncMock(side_effect=mock_write) + + async def control_protocol_generator(): + # Poll for control requests and respond with success to each one. + last_check = 0 + timeout_counter = 0 + while timeout_counter < 200: # Avoid infinite loop + await asyncio.sleep(0.01) + timeout_counter += 1 + + for msg_str in written_messages[last_check:]: + try: + msg = json.loads(msg_str.strip()) + if msg.get("type") == "control_request": + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "response": {}, + }, + } + except (json.JSONDecodeError, KeyError, AttributeError): + pass + last_check = len(written_messages) + + mock_transport.read_messages = control_protocol_generator + return mock_transport + + class TestClaudeSDKClientStreaming: """Test ClaudeSDKClient streaming functionality.""" @@ -467,6 +515,319 @@ async def _test(): anyio.run(_test) + def test_reconnect_mcp_server(self): + """Test reconnect_mcp_server sends correct control request.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = _create_mock_transport_with_control_responses() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.reconnect_mcp_server("my-server") + # Check that a control request was sent via write + write_calls = mock_transport.write.call_args_list + request_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + req = msg.get("request", {}) + if ( + msg.get("type") == "control_request" + and req.get("subtype") == "mcp_reconnect" + ): + # Verify wire format uses camelCase serverName + assert req.get("serverName") == "my-server" + request_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert request_found, "mcp_reconnect control request not found" + + anyio.run(_test) + + def test_reconnect_mcp_server_not_connected(self): + """Test reconnect_mcp_server when not connected raises error.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.reconnect_mcp_server("my-server") + + anyio.run(_test) + + def test_toggle_mcp_server(self): + """Test toggle_mcp_server sends correct control request.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = _create_mock_transport_with_control_responses() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.toggle_mcp_server("my-server", False) + # Check that a control request was sent via write + write_calls = mock_transport.write.call_args_list + request_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + req = msg.get("request", {}) + if ( + msg.get("type") == "control_request" + and req.get("subtype") == "mcp_toggle" + ): + # Verify wire format uses camelCase serverName + assert req.get("serverName") == "my-server" + assert req.get("enabled") is False + request_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert request_found, "mcp_toggle control request not found" + + anyio.run(_test) + + def test_toggle_mcp_server_enabled_true(self): + """Test toggle_mcp_server with enabled=True.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = _create_mock_transport_with_control_responses() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.toggle_mcp_server("other-server", True) + write_calls = mock_transport.write.call_args_list + request_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + req = msg.get("request", {}) + if ( + msg.get("type") == "control_request" + and req.get("subtype") == "mcp_toggle" + ): + assert req.get("serverName") == "other-server" + assert req.get("enabled") is True + request_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert request_found, "mcp_toggle control request not found" + + anyio.run(_test) + + def test_toggle_mcp_server_not_connected(self): + """Test toggle_mcp_server when not connected raises error.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.toggle_mcp_server("my-server", True) + + anyio.run(_test) + + def test_stop_task(self): + """Test stop_task sends correct control request with task_id.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = _create_mock_transport_with_control_responses() + mock_transport_class.return_value = mock_transport + + async with ClaudeSDKClient() as client: + await client.stop_task("task-abc123") + # Check that a control request was sent via write + write_calls = mock_transport.write.call_args_list + request_found = False + for call in write_calls: + data = call[0][0] + try: + msg = json.loads(data.strip()) + req = msg.get("request", {}) + if ( + msg.get("type") == "control_request" + and req.get("subtype") == "stop_task" + ): + assert req.get("task_id") == "task-abc123" + request_found = True + break + except (json.JSONDecodeError, KeyError, AttributeError): + pass + assert request_found, ( + "stop_task control request with task_id not found" + ) + + anyio.run(_test) + + def test_stop_task_not_connected(self): + """Test stop_task when not connected raises error.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.stop_task("task-abc123") + + anyio.run(_test) + + def test_get_mcp_status(self): + """Test get_mcp_status returns McpStatusResponse shape.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_transport_class: + mock_transport = AsyncMock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + mock_transport_class.return_value = mock_transport + + written_messages: list[str] = [] + + async def mock_write(data): + written_messages.append(data) + + mock_transport.write = AsyncMock(side_effect=mock_write) + + # Simulated mcp_status response matching McpServerStatus shape + mcp_status_response = { + "mcpServers": [ + { + "name": "my-http-server", + "status": "connected", + "serverInfo": { + "name": "my-http-server", + "version": "1.0.0", + }, + "config": { + "type": "http", + "url": "https://example.com/mcp", + }, + "scope": "project", + "tools": [ + { + "name": "greet", + "description": "Greet a user", + "annotations": {"readOnly": True}, + }, + {"name": "reset"}, + ], + }, + { + "name": "failed-server", + "status": "failed", + "error": "Connection refused", + }, + { + "name": "proxy-server", + "status": "needs-auth", + "config": { + "type": "claudeai-proxy", + "url": "https://claude.ai/proxy", + "id": "proxy-123", + }, + }, + ] + } + + async def control_protocol_generator(): + last_check = 0 + timeout_counter = 0 + while timeout_counter < 200: + await asyncio.sleep(0.01) + timeout_counter += 1 + + for msg_str in written_messages[last_check:]: + try: + msg = json.loads(msg_str.strip()) + if msg.get("type") == "control_request": + subtype = msg.get("request", {}).get("subtype") + if subtype == "initialize": + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "response": {}, + }, + } + elif subtype == "mcp_status": + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "response": mcp_status_response, + }, + } + except (json.JSONDecodeError, KeyError, AttributeError): + pass + last_check = len(written_messages) + + mock_transport.read_messages = control_protocol_generator + + async with ClaudeSDKClient() as client: + status = await client.get_mcp_status() + + # Verify response conforms to McpStatusResponse shape + assert "mcpServers" in status + servers = status["mcpServers"] + assert len(servers) == 3 + + # Connected server with full info + connected = servers[0] + assert connected["name"] == "my-http-server" + assert connected["status"] == "connected" + assert connected["serverInfo"]["version"] == "1.0.0" + assert connected["config"]["type"] == "http" + assert connected["config"]["url"] == "https://example.com/mcp" + assert connected["scope"] == "project" + assert len(connected["tools"]) == 2 + assert connected["tools"][0]["name"] == "greet" + assert connected["tools"][0]["annotations"]["readOnly"] is True + # Tool without optional fields + assert connected["tools"][1]["name"] == "reset" + assert "description" not in connected["tools"][1] + + # Failed server with error + failed = servers[1] + assert failed["name"] == "failed-server" + assert failed["status"] == "failed" + assert failed["error"] == "Connection refused" + + # Server with claudeai-proxy config + proxy = servers[2] + assert proxy["name"] == "proxy-server" + assert proxy["status"] == "needs-auth" + assert proxy["config"]["type"] == "claudeai-proxy" + assert proxy["config"]["id"] == "proxy-123" + + anyio.run(_test) + + def test_get_mcp_status_not_connected(self): + """Test get_mcp_status when not connected raises error.""" + + async def _test(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError, match="Not connected"): + await client.get_mcp_status() + + anyio.run(_test) + def test_client_with_options(self): """Test client initialization with options.""" diff --git a/tests/test_types.py b/tests/test_types.py index 95a88bfa9..615285f42 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -275,3 +275,98 @@ def test_post_tool_use_output_has_updated_mcp_tool_output(self): "updatedMCPToolOutput": {"result": "modified"}, } assert output["updatedMCPToolOutput"] == {"result": "modified"} + + +class TestMcpServerStatusTypes: + """Test MCP server status type definitions.""" + + def test_mcp_server_status_importable_from_package(self): + """Verify McpServerStatus and related types are exported.""" + from claude_agent_sdk import ( + McpServerConnectionStatus, # noqa: F401 + McpServerInfo, # noqa: F401 + McpServerStatus, # noqa: F401 + McpServerStatusConfig, # noqa: F401 + McpStatusResponse, # noqa: F401 + McpToolAnnotations, # noqa: F401 + McpToolInfo, # noqa: F401 + ) + + def test_mcp_server_status_connected(self): + """Test constructing a connected McpServerStatus with full fields.""" + from claude_agent_sdk import McpServerStatus + + status: McpServerStatus = { + "name": "my-server", + "status": "connected", + "serverInfo": {"name": "my-server", "version": "1.2.3"}, + "config": {"type": "http", "url": "https://example.com"}, + "scope": "project", + "tools": [ + { + "name": "greet", + "description": "Greet a user", + "annotations": { + "readOnly": True, + "destructive": False, + "openWorld": False, + }, + } + ], + } + assert status["name"] == "my-server" + assert status["status"] == "connected" + assert status["serverInfo"]["version"] == "1.2.3" + assert status["tools"][0]["annotations"]["readOnly"] is True + + def test_mcp_server_status_minimal(self): + """Test constructing a minimal McpServerStatus (only required fields).""" + from claude_agent_sdk import McpServerStatus + + status: McpServerStatus = {"name": "pending-server", "status": "pending"} + assert status["name"] == "pending-server" + assert status["status"] == "pending" + assert "error" not in status + assert "config" not in status + + def test_mcp_server_status_failed_with_error(self): + """Test McpServerStatus for a failed server includes error.""" + from claude_agent_sdk import McpServerStatus + + status: McpServerStatus = { + "name": "broken-server", + "status": "failed", + "error": "Connection refused", + } + assert status["status"] == "failed" + assert status["error"] == "Connection refused" + + def test_mcp_server_status_config_claudeai_proxy(self): + """Test McpServerStatusConfig accepts claudeai-proxy variant.""" + from claude_agent_sdk import McpServerStatus + + status: McpServerStatus = { + "name": "proxy-server", + "status": "needs-auth", + "config": { + "type": "claudeai-proxy", + "url": "https://claude.ai/proxy", + "id": "proxy-abc", + }, + } + assert status["config"]["type"] == "claudeai-proxy" + assert status["config"]["id"] == "proxy-abc" + + def test_mcp_status_response_wraps_servers(self): + """Test McpStatusResponse wraps mcpServers list.""" + from claude_agent_sdk import McpStatusResponse + + response: McpStatusResponse = { + "mcpServers": [ + {"name": "a", "status": "connected"}, + {"name": "b", "status": "disabled"}, + ] + } + assert len(response["mcpServers"]) == 2 + assert response["mcpServers"][0]["status"] == "connected" + assert response["mcpServers"][1]["status"] == "disabled"