|
47 | 47 | ModelCapabilitiesOverride as _RpcModelCapabilitiesOverride, |
48 | 48 | ) |
49 | 49 | from .generated.session_events import ( |
| 50 | + AssistantMessageData, |
50 | 51 | CapabilitiesChangedData, |
51 | 52 | CommandExecuteData, |
52 | 53 | ElicitationRequestedData, |
53 | 54 | ExternalToolRequestedData, |
54 | 55 | PermissionRequest, |
55 | 56 | PermissionRequestedData, |
56 | 57 | SessionEvent, |
| 58 | + SessionErrorData, |
57 | 59 | SessionEventType, |
| 60 | + SessionIdleData, |
58 | 61 | session_event_from_dict, |
59 | 62 | ) |
60 | 63 | from .tools import Tool, ToolHandler, ToolInvocation, ToolResult |
@@ -1130,24 +1133,25 @@ async def send_and_wait( |
1130 | 1133 | Example: |
1131 | 1134 | >>> from copilot.generated.session_events import AssistantMessageData |
1132 | 1135 | >>> response = await session.send_and_wait("What is 2+2?") |
1133 | | - >>> if response and isinstance(response.data, AssistantMessageData): |
1134 | | - ... print(response.data.content) |
| 1136 | + >>> if response: |
| 1137 | + ... match response.data: |
| 1138 | + ... case AssistantMessageData() as data: |
| 1139 | + ... print(data.content) |
1135 | 1140 | """ |
1136 | 1141 | idle_event = asyncio.Event() |
1137 | 1142 | error_event: Exception | None = None |
1138 | 1143 | last_assistant_message: SessionEvent | None = None |
1139 | 1144 |
|
1140 | 1145 | def handler(event: SessionEventTypeAlias) -> None: |
1141 | 1146 | nonlocal last_assistant_message, error_event |
1142 | | - if event.type == SessionEventType.ASSISTANT_MESSAGE: |
1143 | | - last_assistant_message = event |
1144 | | - elif event.type == SessionEventType.SESSION_IDLE: |
1145 | | - idle_event.set() |
1146 | | - elif event.type == SessionEventType.SESSION_ERROR: |
1147 | | - error_event = Exception( |
1148 | | - f"Session error: {getattr(event.data, 'message', str(event.data))}" |
1149 | | - ) |
1150 | | - idle_event.set() |
| 1147 | + match event.data: |
| 1148 | + case AssistantMessageData(): |
| 1149 | + last_assistant_message = event |
| 1150 | + case SessionIdleData(): |
| 1151 | + idle_event.set() |
| 1152 | + case SessionErrorData() as data: |
| 1153 | + error_event = Exception(f"Session error: {data.message or str(data)}") |
| 1154 | + idle_event.set() |
1151 | 1155 |
|
1152 | 1156 | unsubscribe = self.on(handler) |
1153 | 1157 | try: |
@@ -1179,10 +1183,11 @@ def on(self, handler: Callable[[SessionEvent], None]) -> Callable[[], None]: |
1179 | 1183 | Example: |
1180 | 1184 | >>> from copilot.generated.session_events import AssistantMessageData, SessionErrorData |
1181 | 1185 | >>> def handle_event(event): |
1182 | | - ... if isinstance(event.data, AssistantMessageData): |
1183 | | - ... print(f"Assistant: {event.data.content}") |
1184 | | - ... elif isinstance(event.data, SessionErrorData): |
1185 | | - ... print(f"Error: {event.data.message}") |
| 1186 | + ... match event.data: |
| 1187 | + ... case AssistantMessageData() as data: |
| 1188 | + ... print(f"Assistant: {data.content}") |
| 1189 | + ... case SessionErrorData() as data: |
| 1190 | + ... print(f"Error: {data.message}") |
1186 | 1191 | >>> unsubscribe = session.on(handle_event) |
1187 | 1192 | >>> # Later, to stop receiving events: |
1188 | 1193 | >>> unsubscribe() |
@@ -1228,90 +1233,89 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: |
1228 | 1233 | Implements the protocol v3 broadcast model where tool calls and permission requests |
1229 | 1234 | are broadcast as session events to all clients. |
1230 | 1235 | """ |
1231 | | - data = event.data |
1232 | | - |
1233 | | - if isinstance(data, ExternalToolRequestedData): |
1234 | | - request_id = data.request_id |
1235 | | - tool_name = data.tool_name |
1236 | | - if not request_id or not tool_name: |
1237 | | - return |
1238 | | - |
1239 | | - handler = self._get_tool_handler(tool_name) |
1240 | | - if not handler: |
1241 | | - return # This client doesn't handle this tool; another client will. |
1242 | | - |
1243 | | - tool_call_id = data.tool_call_id or "" |
1244 | | - arguments = data.arguments |
1245 | | - tp = getattr(data, "traceparent", None) |
1246 | | - ts = getattr(data, "tracestate", None) |
1247 | | - asyncio.ensure_future( |
1248 | | - self._execute_tool_and_respond( |
1249 | | - request_id, tool_name, tool_call_id, arguments, handler, tp, ts |
| 1236 | + match event.data: |
| 1237 | + case ExternalToolRequestedData() as data: |
| 1238 | + request_id = data.request_id |
| 1239 | + tool_name = data.tool_name |
| 1240 | + if not request_id or not tool_name: |
| 1241 | + return |
| 1242 | + |
| 1243 | + handler = self._get_tool_handler(tool_name) |
| 1244 | + if not handler: |
| 1245 | + return # This client doesn't handle this tool; another client will. |
| 1246 | + |
| 1247 | + tool_call_id = data.tool_call_id or "" |
| 1248 | + arguments = data.arguments |
| 1249 | + tp = getattr(data, "traceparent", None) |
| 1250 | + ts = getattr(data, "tracestate", None) |
| 1251 | + asyncio.ensure_future( |
| 1252 | + self._execute_tool_and_respond( |
| 1253 | + request_id, tool_name, tool_call_id, arguments, handler, tp, ts |
| 1254 | + ) |
1250 | 1255 | ) |
1251 | | - ) |
1252 | 1256 |
|
1253 | | - elif isinstance(data, PermissionRequestedData): |
1254 | | - request_id = data.request_id |
1255 | | - permission_request = data.permission_request |
1256 | | - if not request_id or not permission_request: |
1257 | | - return |
| 1257 | + case PermissionRequestedData() as data: |
| 1258 | + request_id = data.request_id |
| 1259 | + permission_request = data.permission_request |
| 1260 | + if not request_id or not permission_request: |
| 1261 | + return |
1258 | 1262 |
|
1259 | | - resolved_by_hook = getattr(data, "resolved_by_hook", None) |
1260 | | - if resolved_by_hook: |
1261 | | - return # Already resolved by a permissionRequest hook; no client action needed. |
| 1263 | + resolved_by_hook = getattr(data, "resolved_by_hook", None) |
| 1264 | + if resolved_by_hook: |
| 1265 | + return # Already resolved by a permissionRequest hook; no client action needed. |
1262 | 1266 |
|
1263 | | - with self._permission_handler_lock: |
1264 | | - perm_handler = self._permission_handler |
1265 | | - if not perm_handler: |
1266 | | - return # This client doesn't handle permissions; another client will. |
| 1267 | + with self._permission_handler_lock: |
| 1268 | + perm_handler = self._permission_handler |
| 1269 | + if not perm_handler: |
| 1270 | + return # This client doesn't handle permissions; another client will. |
1267 | 1271 |
|
1268 | | - asyncio.ensure_future( |
1269 | | - self._execute_permission_and_respond(request_id, permission_request, perm_handler) |
1270 | | - ) |
| 1272 | + asyncio.ensure_future( |
| 1273 | + self._execute_permission_and_respond(request_id, permission_request, perm_handler) |
| 1274 | + ) |
1271 | 1275 |
|
1272 | | - elif isinstance(data, CommandExecuteData): |
1273 | | - request_id = data.request_id |
1274 | | - command_name = data.command_name |
1275 | | - command = data.command |
1276 | | - args = data.args |
1277 | | - if not request_id or not command_name: |
1278 | | - return |
1279 | | - asyncio.ensure_future( |
1280 | | - self._execute_command_and_respond( |
1281 | | - request_id, command_name, command or "", args or "" |
| 1276 | + case CommandExecuteData() as data: |
| 1277 | + request_id = data.request_id |
| 1278 | + command_name = data.command_name |
| 1279 | + command = data.command |
| 1280 | + args = data.args |
| 1281 | + if not request_id or not command_name: |
| 1282 | + return |
| 1283 | + asyncio.ensure_future( |
| 1284 | + self._execute_command_and_respond( |
| 1285 | + request_id, command_name, command or "", args or "" |
| 1286 | + ) |
1282 | 1287 | ) |
1283 | | - ) |
1284 | 1288 |
|
1285 | | - elif isinstance(data, ElicitationRequestedData): |
1286 | | - with self._elicitation_handler_lock: |
1287 | | - handler = self._elicitation_handler |
1288 | | - if not handler: |
1289 | | - return |
1290 | | - request_id = data.request_id |
1291 | | - if not request_id: |
1292 | | - return |
1293 | | - context: ElicitationContext = { |
1294 | | - "session_id": self.session_id, |
1295 | | - "message": data.message or "", |
1296 | | - } |
1297 | | - if data.requested_schema is not None: |
1298 | | - context["requestedSchema"] = data.requested_schema.to_dict() |
1299 | | - if data.mode is not None: |
1300 | | - context["mode"] = data.mode.value |
1301 | | - if data.elicitation_source is not None: |
1302 | | - context["elicitationSource"] = data.elicitation_source |
1303 | | - if data.url is not None: |
1304 | | - context["url"] = data.url |
1305 | | - asyncio.ensure_future(self._handle_elicitation_request(context, request_id)) |
1306 | | - |
1307 | | - elif isinstance(data, CapabilitiesChangedData): |
1308 | | - cap: SessionCapabilities = {} |
1309 | | - if data.ui is not None: |
1310 | | - ui_cap: SessionUiCapabilities = {} |
1311 | | - if data.ui.elicitation is not None: |
1312 | | - ui_cap["elicitation"] = data.ui.elicitation |
1313 | | - cap["ui"] = ui_cap |
1314 | | - self._capabilities = {**self._capabilities, **cap} |
| 1289 | + case ElicitationRequestedData() as data: |
| 1290 | + with self._elicitation_handler_lock: |
| 1291 | + handler = self._elicitation_handler |
| 1292 | + if not handler: |
| 1293 | + return |
| 1294 | + request_id = data.request_id |
| 1295 | + if not request_id: |
| 1296 | + return |
| 1297 | + context: ElicitationContext = { |
| 1298 | + "session_id": self.session_id, |
| 1299 | + "message": data.message or "", |
| 1300 | + } |
| 1301 | + if data.requested_schema is not None: |
| 1302 | + context["requestedSchema"] = data.requested_schema.to_dict() |
| 1303 | + if data.mode is not None: |
| 1304 | + context["mode"] = data.mode.value |
| 1305 | + if data.elicitation_source is not None: |
| 1306 | + context["elicitationSource"] = data.elicitation_source |
| 1307 | + if data.url is not None: |
| 1308 | + context["url"] = data.url |
| 1309 | + asyncio.ensure_future(self._handle_elicitation_request(context, request_id)) |
| 1310 | + |
| 1311 | + case CapabilitiesChangedData() as data: |
| 1312 | + cap: SessionCapabilities = {} |
| 1313 | + if data.ui is not None: |
| 1314 | + ui_cap: SessionUiCapabilities = {} |
| 1315 | + if data.ui.elicitation is not None: |
| 1316 | + ui_cap["elicitation"] = data.ui.elicitation |
| 1317 | + cap["ui"] = ui_cap |
| 1318 | + self._capabilities = {**self._capabilities, **cap} |
1315 | 1319 |
|
1316 | 1320 | async def _execute_tool_and_respond( |
1317 | 1321 | self, |
@@ -1803,8 +1807,9 @@ async def get_messages(self) -> list[SessionEvent]: |
1803 | 1807 | >>> from copilot.generated.session_events import AssistantMessageData |
1804 | 1808 | >>> events = await session.get_messages() |
1805 | 1809 | >>> for event in events: |
1806 | | - ... if isinstance(event.data, AssistantMessageData): |
1807 | | - ... print(f"Assistant: {event.data.content}") |
| 1810 | + ... match event.data: |
| 1811 | + ... case AssistantMessageData() as data: |
| 1812 | + ... print(f"Assistant: {data.content}") |
1808 | 1813 | """ |
1809 | 1814 | response = await self._client.request("session.getMessages", {"sessionId": self.session_id}) |
1810 | 1815 | # Convert dict events to SessionEvent objects |
|
0 commit comments