-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent_loop.py
More file actions
119 lines (94 loc) · 3.81 KB
/
agent_loop.py
File metadata and controls
119 lines (94 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import asyncio
from typing import Callable, Awaitable, Any
from dataclasses import dataclass
from claude_agent_sdk import (
ClaudeAgentOptions,
ClaudeSDKClient,
)
from claude_agent_sdk.types import (
AssistantMessage,
SystemMessage,
TextBlock,
ToolUseBlock,
)
from agent_chat_cli.utils.config import load_config
from agent_chat_cli.utils.enums import AgentMessageType, ContentType, ControlCommand
@dataclass
class AgentMessage:
    """A translated message that ``AgentLoop`` hands to its ``on_message`` callback.

    Attributes:
        type: Which kind of message this is (``STREAM_EVENT``, ``ASSISTANT``,
            ``RESULT`` are the values emitted by ``AgentLoop``).
        data: Payload whose shape depends on ``type``: ``{"text": str}`` for
            stream events, ``{"content": [...]}`` for assistant messages, and
            ``None`` for the end-of-turn result marker.
    """

    type: AgentMessageType
    data: Any
class AgentLoop:
    """Drive a ``ClaudeSDKClient`` from a queue of user inputs.

    Inputs (plain strings or ``ControlCommand`` values) are taken from
    ``query_queue`` by ``start()``; every message the SDK returns for a query is
    translated into zero or more ``AgentMessage`` objects and passed to the
    ``on_message`` callback.
    """

    def __init__(
        self,
        on_message: Callable[[AgentMessage], Awaitable[None]],
        session_id: str | None = None,
    ) -> None:
        """Load the configuration and build (but do not connect) the SDK client.

        Args:
            on_message: Async callback invoked with every translated message.
            session_id: Optional session to resume; forwarded to the SDK as the
                ``resume`` option.
        """
        self.config = load_config()
        self.session_id = session_id
        self.client = self._build_client(session_id)
        self.on_message = on_message
        # Consumed by start(): user text to query, or a ControlCommand.
        self.query_queue: asyncio.Queue[str | ControlCommand] = asyncio.Queue()
        self._running = False
        # Only read here; presumably set to True by an interrupt handler outside
        # this class (verify against callers) to drop the rest of a response.
        self.interrupting = False

    def _build_client(self, session_id: str | None) -> ClaudeSDKClient:
        """Create an SDK client from the loaded config, resuming *session_id* if truthy."""
        config_dict = self.config.model_dump()
        if session_id:
            config_dict["resume"] = session_id
        return ClaudeSDKClient(options=ClaudeAgentOptions(**config_dict))

    async def _start_new_conversation(self) -> None:
        """Replace the connected client with a fresh, non-resuming one."""
        await self.client.disconnect()
        # The current client's options may include ``resume=<session_id>`` (see
        # _build_client); reconnecting that same client would presumably resume
        # the old session, so build a new client without ``resume``.
        self.session_id = None
        self.client = self._build_client(None)
        await self.client.connect()

    async def start(self) -> None:
        """Connect the client and process queued inputs until stopped.

        ``ControlCommand.NEW_CONVERSATION`` swaps in a fresh client; any other
        control command is ignored. Each text input is sent with
        ``client.query``; its response is forwarded through ``_handle_message``
        (skipped while ``interrupting`` is set) and followed by a ``RESULT``
        message marking the end of the turn.

        The client is disconnected when the loop exits for any reason,
        including an exception or task cancellation.
        """
        await self.client.connect()
        self._running = True
        try:
            while self._running:
                user_input = await self.query_queue.get()

                if isinstance(user_input, ControlCommand):
                    if user_input == ControlCommand.NEW_CONVERSATION:
                        await self._start_new_conversation()
                    continue

                self.interrupting = False
                await self.client.query(user_input)

                async for message in self.client.receive_response():
                    # Keep iterating to the end of the response while
                    # interrupting; the skipped messages are just not forwarded.
                    if self.interrupting:
                        continue
                    await self._handle_message(message)

                await self.on_message(
                    AgentMessage(type=AgentMessageType.RESULT, data=None)
                )
        finally:
            self._running = False
            await self.client.disconnect()

    async def _handle_message(self, message: Any) -> None:
        """Translate one SDK message into ``on_message`` callbacks.

        - ``SystemMessage`` with the init subtype: records its ``session_id``.
        - Any message with an ``event`` attribute: a non-empty text delta inside
          a content-block-delta event is forwarded as ``STREAM_EVENT``.
        - ``AssistantMessage``: its text and tool-use blocks are forwarded
          together as one ``ASSISTANT`` message.

        Anything else produces no callback.
        """
        if isinstance(message, SystemMessage):
            if message.subtype == AgentMessageType.INIT.value and (
                new_session_id := message.data.get("session_id")
            ):
                self.session_id = new_session_id

        if hasattr(message, "event"):
            event = message.event  # type: ignore[attr-defined]
            if event.get("type") != ContentType.CONTENT_BLOCK_DELTA.value:
                return
            delta = event.get("delta", {})
            if delta.get("type") != ContentType.TEXT_DELTA.value:
                return
            text_chunk = delta.get("text", "")
            if text_chunk:
                await self.on_message(
                    AgentMessage(
                        type=AgentMessageType.STREAM_EVENT,
                        data={"text": text_chunk},
                    )
                )
        elif isinstance(message, AssistantMessage):
            content = []
            if hasattr(message, "content"):
                for block in message.content:  # type: ignore[attr-defined]
                    if isinstance(block, TextBlock):
                        content.append(
                            {"type": ContentType.TEXT.value, "text": block.text}
                        )
                    elif isinstance(block, ToolUseBlock):
                        content.append(
                            {
                                "type": ContentType.TOOL_USE.value,
                                "id": block.id,
                                "name": block.name,
                                "input": block.input,  # type: ignore[dict-item]
                            }
                        )
            await self.on_message(
                AgentMessage(
                    type=AgentMessageType.ASSISTANT,
                    data={"content": content},
                )
            )