forked from agentclientprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path agent.py
More file actions
123 lines (105 loc) · 4.15 KB
/
agent.py
File metadata and controls
123 lines (105 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import asyncio
import logging
from typing import Any
from acp import (
Agent,
AgentSideConnection,
AuthenticateResponse,
InitializeResponse,
LoadSessionResponse,
NewSessionResponse,
PromptResponse,
SetSessionModeResponse,
stdio_streams,
text_block,
update_agent_message,
PROTOCOL_VERSION,
)
from acp.schema import (
AgentCapabilities,
AgentMessageChunk,
AudioContentBlock,
ClientCapabilities,
EmbeddedResourceContentBlock,
HttpMcpServer,
ImageContentBlock,
Implementation,
ResourceContentBlock,
SseMcpServer,
StdioMcpServer,
TextContentBlock,
)
class ExampleAgent(Agent):
    """Minimal echo agent for the Agent Client Protocol.

    Keeps an in-memory set of known session ids and answers every prompt by
    streaming back a "Client sent:" header followed by each content block the
    client sent.
    """

    def __init__(self, conn: AgentSideConnection) -> None:
        # Connection used to push session updates back to the client.
        self._conn = conn
        # Next candidate numeric id for new_session(); ids are its str() form.
        self._next_session_id = 0
        # Every session id this agent has created, loaded, or been prompted on.
        self._sessions: set[str] = set()

    async def _send_agent_message(self, session_id: str, content: Any) -> None:
        """Send *content* to the client as an agent message update.

        An ``AgentMessageChunk`` is forwarded unchanged; any other value
        (e.g. a content block) is wrapped with ``update_agent_message`` first.
        """
        update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content)
        await self._conn.session_update(session_id, update)

    async def initialize(
        self,
        protocol_version: int,
        client_capabilities: ClientCapabilities | None = None,
        client_info: Implementation | None = None,
        **kwargs: Any,
    ) -> InitializeResponse:
        """Answer the client's initialize request.

        Always reports ``PROTOCOL_VERSION``, default ``AgentCapabilities()``
        and a fixed agent identity. The requested ``protocol_version`` and the
        client's capabilities/info are not inspected.
        """
        logging.info("Received initialize request")
        # NOTE(review): the requested protocol_version is ignored; confirm
        # whether the agent should echo it back when it is supported.
        return InitializeResponse(
            protocol_version=PROTOCOL_VERSION,
            agent_capabilities=AgentCapabilities(),
            agent_info=Implementation(name="example-agent", title="Example Agent", version="0.1.0"),
        )

    async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None:
        """Accept any authentication method without performing a check."""
        logging.info("Received authenticate request %s", method_id)
        return AuthenticateResponse()

    async def new_session(
        self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any
    ) -> NewSessionResponse:
        """Create a session with a fresh numeric-string id.

        ``cwd`` and ``mcp_servers`` are accepted but unused by this example.

        Returns:
            A ``NewSessionResponse`` carrying the new id and no modes.
        """
        logging.info("Received new session request")
        session_id = str(self._next_session_id)
        # load_session()/prompt() can register client-chosen ids, so skip any
        # counter value already in use rather than returning a duplicate id.
        while session_id in self._sessions:
            self._next_session_id += 1
            session_id = str(self._next_session_id)
        self._next_session_id += 1
        self._sessions.add(session_id)
        return NewSessionResponse(session_id=session_id, modes=None)

    async def load_session(
        self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any
    ) -> LoadSessionResponse | None:
        """Register *session_id* as known and acknowledge the load.

        No conversation history is replayed to the client.
        """
        logging.info("Received load session request %s", session_id)
        self._sessions.add(session_id)
        return LoadSessionResponse()

    async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None:
        """Acknowledge a mode change; no per-session mode is stored."""
        logging.info("Received set session mode request %s -> %s", session_id, mode_id)
        return SetSessionModeResponse()

    async def prompt(
        self,
        prompt: list[
            TextContentBlock
            | ImageContentBlock
            | AudioContentBlock
            | ResourceContentBlock
            | EmbeddedResourceContentBlock
        ],
        session_id: str,
        **kwargs: Any,
    ) -> PromptResponse:
        """Echo the prompt back to the client and end the turn.

        Sends a "Client sent:" text message, then one agent message per
        content block in *prompt*, in order. An unknown *session_id* is
        registered implicitly rather than rejected.

        Returns:
            A ``PromptResponse`` with stop reason ``"end_turn"``.
        """
        logging.info("Received prompt request for session %s", session_id)
        # set.add is idempotent, so no membership check is needed first.
        self._sessions.add(session_id)
        await self._send_agent_message(session_id, text_block("Client sent:"))
        for block in prompt:
            await self._send_agent_message(session_id, block)
        return PromptResponse(stop_reason="end_turn")

    async def cancel(self, session_id: str, **kwargs: Any) -> None:
        """Log a cancel notification; this example has no work to abort."""
        logging.info("Received cancel notification for session %s", session_id)

    async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
        """Handle any extension method by returning a fixed example payload."""
        logging.info("Received extension method call: %s", method)
        return {"example": "response"}

    async def ext_notification(self, method: str, params: dict[str, Any]) -> None:
        """Log an extension notification; *params* are ignored."""
        logging.info("Received extension notification: %s", method)
async def main() -> None:
    """Serve ``ExampleAgent`` to a client over stdin/stdout.

    Configures INFO-level logging, wires the agent to the process's stdio
    streams, and then waits indefinitely.
    """
    logging.basicConfig(level=logging.INFO)
    reader, writer = await stdio_streams()
    # ExampleAgent is passed as a factory; presumably the connection calls it
    # with itself, matching ExampleAgent.__init__(conn) — confirm in the SDK.
    # Keep a strong reference for the lifetime of main(): asyncio only weakly
    # references tasks, so if the connection drives its I/O with background
    # tasks, discarding the object risks it being garbage-collected mid-run.
    _conn = AgentSideConnection(ExampleAgent, writer, reader)
    # Block forever while the connection handles traffic in the background.
    # NOTE(review): this never returns after the client closes stdin; consider
    # awaiting a connection-closed signal if the SDK exposes one.
    await asyncio.Event().wait()


if __name__ == "__main__":
    asyncio.run(main())