-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy pathrouter.py
More file actions
128 lines (109 loc) · 4.23 KB
/
router.py
File metadata and controls
128 lines (109 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import annotations
from typing import Any
from pydantic import BaseModel
from ..exceptions import RequestError
from ..interfaces import Agent
from ..meta import AGENT_METHODS
from ..router import MessageRouter, Route, _resolve_handler, _warn_legacy_handler
from ..schema import (
AuthenticateRequest,
CancelNotification,
CloseSessionRequest,
ForkSessionRequest,
InitializeRequest,
ListSessionsRequest,
LoadSessionRequest,
NewSessionRequest,
PromptRequest,
ResumeSessionRequest,
SetSessionConfigOptionBooleanRequest,
SetSessionConfigOptionSelectRequest,
SetSessionModelRequest,
SetSessionModeRequest,
)
from ..utils import model_to_kwargs, normalize_result
# Public API of this module.
__all__ = [
    "build_agent_router",
]

# The set_config_option request variants. This tuple is handed to
# ``model_to_kwargs`` when invoking non-legacy set_config_option handlers.
_SET_CONFIG_OPTION_MODELS = (
    SetSessionConfigOptionBooleanRequest,
    SetSessionConfigOptionSelectRequest,
)
def _validate_set_config_option_request(params: Any) -> BaseModel:
    """Validate raw set_config_option params into the matching request model.

    A ``dict`` whose ``"type"`` entry equals ``"boolean"`` is validated as a
    ``SetSessionConfigOptionBooleanRequest``. Anything else is validated as a
    ``SetSessionConfigOptionSelectRequest``, including non-dict input and dicts
    without a ``"type"`` key.

    Args:
        params: The raw request parameters.

    Returns:
        The validated request model instance. Any error raised by
        ``model_validate`` propagates unchanged.
    """
    # The isinstance check short-circuits, so ``.get`` is only called on dicts.
    wants_boolean = isinstance(params, dict) and params.get("type") == "boolean"
    target_model = (
        SetSessionConfigOptionBooleanRequest if wants_boolean else SetSessionConfigOptionSelectRequest
    )
    return target_model.model_validate(params)
def _make_set_config_option_handler(agent: Agent) -> Any:
    """Build the async request handler for the set_config_option method.

    The agent's handler is located with ``_resolve_handler(agent, "set_config_option")``.

    Args:
        agent: The agent whose handler should be wrapped.

    Returns:
        ``None`` when ``_resolve_handler`` yields no callable. Otherwise, an
        async callable that validates the raw params and dispatches them.
        Legacy handlers get the validated model as a single positional
        argument, after a ``_warn_legacy_handler`` call on every invocation.
        Current handlers get the model expanded to keyword arguments via
        ``model_to_kwargs``.
    """
    func, attr, legacy_api = _resolve_handler(agent, "set_config_option")
    if func is None:
        return None

    async def wrapper(params: Any) -> Any:
        if not legacy_api:
            parsed = _validate_set_config_option_request(params)
            return await func(**model_to_kwargs(parsed, _SET_CONFIG_OPTION_MODELS))
        # Legacy path: the warning is emitted before validation, so it fires
        # even when the params turn out to be invalid.
        _warn_legacy_handler(agent, attr)
        parsed = _validate_set_config_option_request(params)
        return await func(parsed)

    return wrapper
def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> MessageRouter:
    """Create a ``MessageRouter`` that dispatches incoming messages to *agent*.

    Registers the following routes:

    - request routes for the methods keyed in ``AGENT_METHODS``, each bound to
      the corresponding agent attribute;
    - a custom set_config_option route built by ``_make_set_config_option_handler``;
    - the ``session_cancel`` notification route;
    - handlers for extension requests and extension notifications.

    Args:
        agent: The agent implementation whose methods handle routed messages.
        use_unstable_protocol: Forwarded to ``MessageRouter``. NOTE(review):
            presumably gates routes registered with ``unstable=True``; confirm
            against ``MessageRouter``.

    Returns:
        The configured router.
    """
    router = MessageRouter(use_unstable_protocol=use_unstable_protocol)

    # Extra keyword arguments passed through to ``route_request``.
    plain: dict[str, Any] = {}
    normalized = {"adapt_result": normalize_result}
    normalized_unstable = {"adapt_result": normalize_result, "unstable": True}
    unstable_only = {"unstable": True}

    def register_requests(specs: Any) -> None:
        # Each spec is (AGENT_METHODS key, request model, agent attribute, extra kwargs).
        for method_key, request_model, attr_name, extra in specs:
            router.route_request(AGENT_METHODS[method_key], request_model, agent, attr_name, **extra)

    # These are registered before the set_config_option route; the sequence
    # below keeps the exact registration order.
    register_requests(
        (
            ("initialize", InitializeRequest, "initialize", plain),
            ("session_new", NewSessionRequest, "new_session", plain),
            ("session_load", LoadSessionRequest, "load_session", normalized),
            ("session_list", ListSessionsRequest, "list_sessions", plain),
            ("session_close", CloseSessionRequest, "close_session", normalized_unstable),
            ("session_set_mode", SetSessionModeRequest, "set_session_mode", normalized),
            ("session_prompt", PromptRequest, "prompt", plain),
            ("session_set_model", SetSessionModelRequest, "set_session_model", normalized_unstable),
        )
    )

    # set_config_option needs custom validation (boolean vs. select payloads),
    # so it is added as an explicit Route rather than through route_request.
    config_option_route = Route(
        method=AGENT_METHODS["session_set_config_option"],
        func=_make_set_config_option_handler(agent),
        kind="request",
        adapt_result=normalize_result,
    )
    router.add_route(config_option_route)

    register_requests(
        (
            ("authenticate", AuthenticateRequest, "authenticate", normalized),
            ("session_fork", ForkSessionRequest, "fork_session", unstable_only),
            ("session_resume", ResumeSessionRequest, "resume_session", unstable_only),
        )
    )

    router.route_notification(AGENT_METHODS["session_cancel"], CancelNotification, agent, "cancel")

    @router.handle_extension_request
    async def _handle_extension_request(name: str, payload: dict[str, Any]) -> Any:
        # Delegate to ``agent.ext_method`` when present; otherwise report the
        # (underscore-prefixed) method as not found.
        ext_method = getattr(agent, "ext_method", None)
        if ext_method is not None:
            return await ext_method(name, payload)
        raise RequestError.method_not_found(f"_{name}")

    @router.handle_extension_notification
    async def _handle_extension_notification(name: str, payload: dict[str, Any]) -> None:
        # Notifications are dropped when the agent has no ``ext_notification``.
        ext_notification = getattr(agent, "ext_notification", None)
        if ext_notification is not None:
            await ext_notification(name, payload)

    return router