diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 60c04083..be6f32b4 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -19,6 +19,7 @@ PermissionMode, PermissionResultAllow, PermissionResultDeny, + PermissionUpdate, SDKControlPermissionRequest, SDKControlRequest, SDKControlResponse, @@ -350,8 +351,12 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: context = ToolPermissionContext( signal=None, # TODO: Add abort signal support - suggestions=permission_request.get("permission_suggestions", []) - or [], + suggestions=[ + PermissionUpdate.from_dict(s) + for s in ( + permission_request.get("permission_suggestions") or [] + ) + ], tool_use_id=permission_request.get("tool_use_id"), agent_id=permission_request.get("agent_id"), blocked_path=permission_request.get("blocked_path"), diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 202535df..1fac8475 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -169,6 +169,27 @@ def to_dict(self) -> dict[str, Any]: return result + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PermissionUpdate": + """Construct a PermissionUpdate from the control protocol dict format (inverse of to_dict).""" + rules = None + if data.get("rules") is not None: + rules = [ + PermissionRuleValue( + tool_name=r["toolName"], + rule_content=r.get("ruleContent"), + ) + for r in data["rules"] + ] + return cls( + type=data["type"], + rules=rules, + behavior=data.get("behavior"), + mode=data.get("mode"), + directories=data.get("directories"), + destination=data.get("destination"), + ) + # Tool callback types @dataclass diff --git a/tests/test_tool_callbacks.py b/tests/test_tool_callbacks.py index f25dd5a5..f0c30dcf 100644 --- a/tests/test_tool_callbacks.py +++ b/tests/test_tool_callbacks.py @@ -97,6 +97,58 @@ async def allow_callback( response = transport.written_messages[0] assert '"behavior": "allow"' in response + @pytest.mark.asyncio + async def test_permission_callback_suggestions_roundtrip(self): + """Suggestions arrive as PermissionUpdate instances and can be echoed back.""" + from claude_agent_sdk.types import PermissionUpdate + + seen_suggestions: list[Any] = [] + + async def always_allow( + tool_name: str, input_data: dict, context: ToolPermissionContext + ) -> PermissionResultAllow: + seen_suggestions.extend(context.suggestions) + persist = [ + s for s in context.suggestions if s.destination == "localSettings" + ] + return PermissionResultAllow(updated_permissions=persist) + + transport = MockTransport() + query = Query( + transport=transport, + is_streaming_mode=True, + can_use_tool=always_allow, + hooks=None, + ) + + wire_suggestion = { + "type": "addRules", + "destination": "localSettings", + "behavior": "allow", + "rules": [{"toolName": "Bash", "ruleContent": "git status"}], + } + request = { + "type": "control_request", + "request_id": "test-roundtrip", + "request": { + "subtype": "can_use_tool", + "tool_name": "Bash", + "input": {"command": "git status"}, + "permission_suggestions": [wire_suggestion], + }, + } + + await query._handle_control_request(request) + + assert len(seen_suggestions) == 1 + assert isinstance(seen_suggestions[0], PermissionUpdate) + assert seen_suggestions[0].destination == "localSettings" + + assert len(transport.written_messages) == 1 + response = json.loads(transport.written_messages[0]) + sent = response["response"]["response"]["updatedPermissions"] + assert sent == [wire_suggestion] + @pytest.mark.asyncio async def test_permission_callback_deny(self): """Test callback that denies tool execution.""" @@ -121,7 +173,7 @@ async def deny_callback( "subtype": "can_use_tool", "tool_name": "DangerousTool", "input": {"command": "rm -rf /"}, - "permission_suggestions": ["deny"], + "permission_suggestions": [], }, } diff --git a/tests/test_types.py b/tests/test_types.py index d5f96c38..fa0f5f53 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -12,6 +12,8 @@ SubagentStartHookSpecificOutput, ) from claude_agent_sdk.types import ( + PermissionRuleValue, + PermissionUpdate, PostToolUseHookSpecificOutput, PreToolUseHookSpecificOutput, TextBlock, @@ -22,6 +24,47 @@ ) +class TestPermissionUpdate: + """Test PermissionUpdate wire-format conversion.""" + + def test_from_dict_to_dict_roundtrip_add_rules(self): + wire = { + "type": "addRules", + "destination": "localSettings", + "behavior": "allow", + "rules": [ + {"toolName": "Bash", "ruleContent": "npm *"}, + {"toolName": "Read", "ruleContent": None}, + ], + } + update = PermissionUpdate.from_dict(wire) + assert update.type == "addRules" + assert update.destination == "localSettings" + assert update.behavior == "allow" + assert update.rules == [ + PermissionRuleValue(tool_name="Bash", rule_content="npm *"), + PermissionRuleValue(tool_name="Read", rule_content=None), + ] + assert update.to_dict() == wire + + def test_from_dict_set_mode(self): + wire = {"type": "setMode", "mode": "acceptEdits", "destination": "session"} + update = PermissionUpdate.from_dict(wire) + assert update.mode == "acceptEdits" + assert update.rules is None + assert update.to_dict() == wire + + def test_from_dict_directories(self): + wire = { + "type": "addDirectories", + "directories": ["/tmp/a", "/tmp/b"], + "destination": "userSettings", + } + update = PermissionUpdate.from_dict(wire) + assert update.directories == ["/tmp/a", "/tmp/b"] + assert update.to_dict() == wire + + class TestMessageTypes: """Test message type creation and validation."""