@@ -107,6 +107,54 @@ async def control_protocol_generator():
107107 return mock_transport
108108
109109
110+ def _create_mock_transport_with_control_responses ():
111+ """Create a mock transport that responds with success to all control requests.
112+
113+ Useful for testing client methods that send control requests (e.g.
114+ reconnect_mcp_server, toggle_mcp_server) without needing to special-case
115+ each subtype in the mock.
116+ """
117+ mock_transport = AsyncMock ()
118+ mock_transport .connect = AsyncMock ()
119+ mock_transport .close = AsyncMock ()
120+ mock_transport .end_input = AsyncMock ()
121+ mock_transport .is_ready = Mock (return_value = True )
122+
123+ written_messages : list [str ] = []
124+
125+ async def mock_write (data ):
126+ written_messages .append (data )
127+
128+ mock_transport .write = AsyncMock (side_effect = mock_write )
129+
130+ async def control_protocol_generator ():
131+ # Poll for control requests and respond with success to each one.
132+ last_check = 0
133+ timeout_counter = 0
134+ while timeout_counter < 200 : # Avoid infinite loop
135+ await asyncio .sleep (0.01 )
136+ timeout_counter += 1
137+
138+ for msg_str in written_messages [last_check :]:
139+ try :
140+ msg = json .loads (msg_str .strip ())
141+ if msg .get ("type" ) == "control_request" :
142+ yield {
143+ "type" : "control_response" ,
144+ "response" : {
145+ "request_id" : msg .get ("request_id" ),
146+ "subtype" : "success" ,
147+ "response" : {},
148+ },
149+ }
150+ except (json .JSONDecodeError , KeyError , AttributeError ):
151+ pass
152+ last_check = len (written_messages )
153+
154+ mock_transport .read_messages = control_protocol_generator
155+ return mock_transport
156+
157+
110158class TestClaudeSDKClientStreaming :
111159 """Test ClaudeSDKClient streaming functionality."""
112160
@@ -467,6 +515,128 @@ async def _test():
467515
468516 anyio .run (_test )
469517
518+ def test_reconnect_mcp_server (self ):
519+ """Test reconnect_mcp_server sends correct control request."""
520+
521+ async def _test ():
522+ with patch (
523+ "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
524+ ) as mock_transport_class :
525+ mock_transport = _create_mock_transport_with_control_responses ()
526+ mock_transport_class .return_value = mock_transport
527+
528+ async with ClaudeSDKClient () as client :
529+ await client .reconnect_mcp_server ("my-server" )
530+ # Check that a control request was sent via write
531+ write_calls = mock_transport .write .call_args_list
532+ request_found = False
533+ for call in write_calls :
534+ data = call [0 ][0 ]
535+ try :
536+ msg = json .loads (data .strip ())
537+ req = msg .get ("request" , {})
538+ if (
539+ msg .get ("type" ) == "control_request"
540+ and req .get ("subtype" ) == "mcp_reconnect"
541+ ):
542+ # Verify wire format uses camelCase serverName
543+ assert req .get ("serverName" ) == "my-server"
544+ request_found = True
545+ break
546+ except (json .JSONDecodeError , KeyError , AttributeError ):
547+ pass
548+ assert request_found , "mcp_reconnect control request not found"
549+
550+ anyio .run (_test )
551+
552+ def test_reconnect_mcp_server_not_connected (self ):
553+ """Test reconnect_mcp_server when not connected raises error."""
554+
555+ async def _test ():
556+ client = ClaudeSDKClient ()
557+ with pytest .raises (CLIConnectionError , match = "Not connected" ):
558+ await client .reconnect_mcp_server ("my-server" )
559+
560+ anyio .run (_test )
561+
562+ def test_toggle_mcp_server (self ):
563+ """Test toggle_mcp_server sends correct control request."""
564+
565+ async def _test ():
566+ with patch (
567+ "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
568+ ) as mock_transport_class :
569+ mock_transport = _create_mock_transport_with_control_responses ()
570+ mock_transport_class .return_value = mock_transport
571+
572+ async with ClaudeSDKClient () as client :
573+ await client .toggle_mcp_server ("my-server" , False )
574+ # Check that a control request was sent via write
575+ write_calls = mock_transport .write .call_args_list
576+ request_found = False
577+ for call in write_calls :
578+ data = call [0 ][0 ]
579+ try :
580+ msg = json .loads (data .strip ())
581+ req = msg .get ("request" , {})
582+ if (
583+ msg .get ("type" ) == "control_request"
584+ and req .get ("subtype" ) == "mcp_toggle"
585+ ):
586+ # Verify wire format uses camelCase serverName
587+ assert req .get ("serverName" ) == "my-server"
588+ assert req .get ("enabled" ) is False
589+ request_found = True
590+ break
591+ except (json .JSONDecodeError , KeyError , AttributeError ):
592+ pass
593+ assert request_found , "mcp_toggle control request not found"
594+
595+ anyio .run (_test )
596+
597+ def test_toggle_mcp_server_enabled_true (self ):
598+ """Test toggle_mcp_server with enabled=True."""
599+
600+ async def _test ():
601+ with patch (
602+ "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
603+ ) as mock_transport_class :
604+ mock_transport = _create_mock_transport_with_control_responses ()
605+ mock_transport_class .return_value = mock_transport
606+
607+ async with ClaudeSDKClient () as client :
608+ await client .toggle_mcp_server ("other-server" , True )
609+ write_calls = mock_transport .write .call_args_list
610+ request_found = False
611+ for call in write_calls :
612+ data = call [0 ][0 ]
613+ try :
614+ msg = json .loads (data .strip ())
615+ req = msg .get ("request" , {})
616+ if (
617+ msg .get ("type" ) == "control_request"
618+ and req .get ("subtype" ) == "mcp_toggle"
619+ ):
620+ assert req .get ("serverName" ) == "other-server"
621+ assert req .get ("enabled" ) is True
622+ request_found = True
623+ break
624+ except (json .JSONDecodeError , KeyError , AttributeError ):
625+ pass
626+ assert request_found , "mcp_toggle control request not found"
627+
628+ anyio .run (_test )
629+
630+ def test_toggle_mcp_server_not_connected (self ):
631+ """Test toggle_mcp_server when not connected raises error."""
632+
633+ async def _test ():
634+ client = ClaudeSDKClient ()
635+ with pytest .raises (CLIConnectionError , match = "Not connected" ):
636+ await client .toggle_mcp_server ("my-server" , True )
637+
638+ anyio .run (_test )
639+
470640 def test_client_with_options (self ):
471641 """Test client initialization with options."""
472642
0 commit comments