-
Notifications
You must be signed in to change notification settings - Fork 978
Expand file tree
/
Copy pathclient.py
More file actions
163 lines (141 loc) · 6.51 KB
/
client.py
File metadata and controls
163 lines (141 loc) · 6.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Internal client implementation."""
import json
import os
from collections.abc import AsyncIterable, AsyncIterator
from dataclasses import asdict, replace
from typing import Any
from ..types import (
ClaudeAgentOptions,
HookEvent,
HookMatcher,
Message,
)
from .message_parser import parse_message
from .query import Query
from .transport import Transport
from .transport.subprocess_cli import SubprocessCLITransport
class InternalClient:
    """Internal client implementation.

    Drives a single query: validates permission-related options, connects a
    transport, wraps it in a ``Query`` that speaks the control protocol, and
    yields parsed messages back to the caller.
    """

    def __init__(self) -> None:
        """Initialize the internal client (it holds no state)."""

    def _convert_hooks_to_internal_format(
        self, hooks: dict[HookEvent, list[HookMatcher]]
    ) -> dict[str, list[dict[str, Any]]]:
        """Convert HookMatcher format to internal Query format.

        Args:
            hooks: Mapping of hook event name to the matchers registered for it.

        Returns:
            Mapping of the same event names to lists of dicts with keys
            ``"matcher"`` and ``"hooks"``, plus ``"timeout"`` when the matcher
            has a non-``None`` timeout. Attributes missing from a matcher fall
            back to ``None`` (matcher) and ``[]`` (hooks).
        """
        return {
            event: [self._convert_matcher(matcher) for matcher in matchers]
            for event, matchers in hooks.items()
        }

    @staticmethod
    def _convert_matcher(matcher: HookMatcher) -> dict[str, Any]:
        """Convert one HookMatcher into the internal dict format."""
        internal_matcher: dict[str, Any] = {
            "matcher": getattr(matcher, "matcher", None),
            "hooks": getattr(matcher, "hooks", []),
        }
        timeout = getattr(matcher, "timeout", None)
        if timeout is not None:
            internal_matcher["timeout"] = timeout
        return internal_matcher

    @staticmethod
    def _extract_sdk_mcp_servers(options: ClaudeAgentOptions) -> dict[str, Any]:
        """Return in-process ("sdk" type) MCP server instances keyed by name.

        Only a dict-valued ``options.mcp_servers`` is inspected; entries whose
        config is a dict with ``"type" == "sdk"`` contribute their
        ``"instance"``. Any other ``mcp_servers`` value yields an empty dict.
        """
        servers = options.mcp_servers
        if not servers or not isinstance(servers, dict):
            return {}
        return {
            name: config["instance"]  # type: ignore[typeddict-item]
            for name, config in servers.items()
            if isinstance(config, dict) and config.get("type") == "sdk"
        }

    @staticmethod
    def _extract_exclude_dynamic_sections(
        options: ClaudeAgentOptions,
    ) -> bool | None:
        """Read ``exclude_dynamic_sections`` from a preset system prompt.

        Returns the flag only when ``options.system_prompt`` is a dict with
        ``"type" == "preset"`` and the flag is an actual ``bool``; otherwise
        returns ``None``. The value is forwarded in the initialize request
        (older CLIs ignore unknown initialize fields).
        """
        system_prompt = options.system_prompt
        if isinstance(system_prompt, dict) and system_prompt.get("type") == "preset":
            flag = system_prompt.get("exclude_dynamic_sections")
            if isinstance(flag, bool):
                return flag
        return None

    @staticmethod
    def _build_agents_dict(
        options: ClaudeAgentOptions,
    ) -> dict[str, dict[str, Any]] | None:
        """Serialize agent definitions for the initialize request.

        Each agent definition is converted with ``dataclasses.asdict`` and
        ``None``-valued fields are dropped. Returns ``None`` when no agents
        are configured.
        """
        if not options.agents:
            return None
        return {
            name: {
                key: value
                for key, value in asdict(agent_def).items()
                if value is not None
            }
            for name, agent_def in options.agents.items()
        }

    @staticmethod
    def _resolve_initialize_timeout() -> float:
        """Return the initialize timeout in seconds.

        Reads ``CLAUDE_CODE_STREAM_CLOSE_TIMEOUT`` (milliseconds, default
        ``"60000"``) and never returns less than 60 seconds. This matches
        ClaudeSDKClient.connect(); without it, query() ignores the env var.

        Raises:
            ValueError: if the env var is set to a non-integer string.
        """
        timeout_ms = int(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000"))
        return max(timeout_ms / 1000.0, 60.0)

    async def process_query(
        self,
        prompt: str | AsyncIterable[dict[str, Any]],
        options: ClaudeAgentOptions,
        transport: Transport | None = None,
    ) -> AsyncIterator[Message]:
        """Process a query through transport and Query.

        Args:
            prompt: A single user prompt string, or an async iterable of
                message dicts that is streamed to the transport in the
                background.
            options: Query configuration.
            transport: Transport to use. When ``None``, a
                ``SubprocessCLITransport`` is built from ``prompt`` and the
                (possibly adjusted) options.

        Yields:
            Parsed messages. Raw messages for which ``parse_message`` returns
            ``None`` are skipped.

        Raises:
            ValueError: if ``can_use_tool`` is combined with a string prompt or
                with ``permission_prompt_tool_name``, or if
                ``CLAUDE_CODE_STREAM_CLOSE_TIMEOUT`` is not an integer.
        """
        # Validate and configure permission settings (matching TypeScript SDK logic)
        configured_options = options
        if options.can_use_tool:
            # canUseTool callback requires streaming mode (AsyncIterable prompt)
            if isinstance(prompt, str):
                raise ValueError(
                    "can_use_tool callback requires streaming mode. "
                    "Please provide prompt as an AsyncIterable instead of a string."
                )
            # canUseTool and permission_prompt_tool_name are mutually exclusive
            if options.permission_prompt_tool_name:
                raise ValueError(
                    "can_use_tool callback cannot be used with permission_prompt_tool_name. "
                    "Please use one or the other."
                )
            # Automatically set permission_prompt_tool_name to "stdio" for control protocol
            configured_options = replace(options, permission_prompt_tool_name="stdio")

        # Derive everything the Query needs BEFORE connecting. These steps can
        # raise (e.g. a non-integer CLAUDE_CODE_STREAM_CLOSE_TIMEOUT), and nothing
        # between connect() and the try/finally below would release an
        # already-connected transport.
        sdk_mcp_servers = self._extract_sdk_mcp_servers(configured_options)
        exclude_dynamic_sections = self._extract_exclude_dynamic_sections(
            configured_options
        )
        agents_dict = self._build_agents_dict(configured_options)
        internal_hooks = (
            self._convert_hooks_to_internal_format(configured_options.hooks)
            if configured_options.hooks
            else None
        )
        initialize_timeout = self._resolve_initialize_timeout()

        # Use provided transport or create subprocess transport
        if transport is not None:
            chosen_transport = transport
        else:
            chosen_transport = SubprocessCLITransport(
                prompt=prompt,
                options=configured_options,
            )

        # Connect transport
        await chosen_transport.connect()

        # Create Query to handle control protocol
        # Always use streaming mode internally (matching TypeScript SDK)
        # This ensures agents are always sent via initialize request
        query = Query(
            transport=chosen_transport,
            is_streaming_mode=True,  # Always streaming internally
            can_use_tool=configured_options.can_use_tool,
            hooks=internal_hooks,
            sdk_mcp_servers=sdk_mcp_servers,
            initialize_timeout=initialize_timeout,
            agents=agents_dict,
            exclude_dynamic_sections=exclude_dynamic_sections,
        )

        try:
            # Start reading messages
            await query.start()

            # Always initialize to send agents via stdin (matching TypeScript SDK)
            await query.initialize()

            # Handle prompt input
            if isinstance(prompt, str):
                # For string prompts, write user message to stdin after initialize
                # (matching TypeScript SDK behavior)
                user_message = {
                    "type": "user",
                    "session_id": "",
                    "message": {"role": "user", "content": prompt},
                    "parent_tool_use_id": None,
                }
                await chosen_transport.write(json.dumps(user_message) + "\n")
                query.spawn_task(query.wait_for_result_and_end_input())
            elif isinstance(prompt, AsyncIterable):
                # Stream input in background for async iterables
                query.spawn_task(query.stream_input(prompt))

            # Yield parsed messages, skipping unknown message types
            async for data in query.receive_messages():
                message = parse_message(data)
                if message is not None:
                    yield message
        finally:
            await query.close()