|
9 | 9 |
|
10 | 10 | import inspect |
11 | 11 | import logging |
12 | | -import types |
13 | 12 | from collections.abc import Awaitable, Generator |
14 | 13 | from dataclasses import dataclass |
15 | | -from typing import ( |
16 | | - TYPE_CHECKING, |
17 | | - Any, |
18 | | - Generic, |
19 | | - Protocol, |
20 | | - TypeVar, |
21 | | - Union, |
22 | | - cast, |
23 | | - get_args, |
24 | | - get_origin, |
25 | | - get_type_hints, |
26 | | - runtime_checkable, |
27 | | -) |
| 14 | +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable |
28 | 15 |
|
29 | 16 | from ..interrupt import Interrupt, InterruptException |
| 17 | +from ._type_inference import infer_event_types |
30 | 18 |
|
31 | 19 | if TYPE_CHECKING: |
32 | 20 | from ..agent import Agent |
@@ -225,7 +213,7 @@ def multi_handler(event): |
225 | 213 | resolved_event_types = self._validate_event_type_list(event_type) |
226 | 214 | elif event_type is None: |
227 | 215 | # Infer event type(s) from callback type hints |
228 | | - resolved_event_types = self._infer_event_types(callback) |
| 216 | + resolved_event_types = infer_event_types(callback) |
229 | 217 | else: |
230 | 218 | # Single event type provided explicitly |
231 | 219 | resolved_event_types = [event_type] |
@@ -261,67 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ |
261 | 249 | validated.append(et) |
262 | 250 | return validated |
263 | 251 |
|
264 | | - def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: |
265 | | - """Infer the event type(s) from a callback's type hints. |
266 | | -
|
267 | | - Supports both single types and union types (A | B or Union[A, B]). |
268 | | -
|
269 | | - Args: |
270 | | - callback: The callback function to inspect. |
271 | | -
|
272 | | - Returns: |
273 | | - A list of event types inferred from the callback's first parameter type hint. |
274 | | -
|
275 | | - Raises: |
276 | | - ValueError: If the event type cannot be inferred from the callback's type hints, |
277 | | - or if a union contains None or non-BaseHookEvent types. |
278 | | - """ |
279 | | - try: |
280 | | - hints = get_type_hints(callback) |
281 | | - except Exception as e: |
282 | | - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) |
283 | | - raise ValueError( |
284 | | - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" |
285 | | - ) from e |
286 | | - |
287 | | - # Get the first parameter's type hint |
288 | | - sig = inspect.signature(callback) |
289 | | - params = list(sig.parameters.values()) |
290 | | - |
291 | | - if not params: |
292 | | - raise ValueError( |
293 | | - "callback has no parameters | cannot infer event type, please provide event_type explicitly" |
294 | | - ) |
295 | | - |
296 | | - first_param = params[0] |
297 | | - type_hint = hints.get(first_param.name) |
298 | | - |
299 | | - if type_hint is None: |
300 | | - raise ValueError( |
301 | | - f"parameter=<{first_param.name}> has no type hint | " |
302 | | - "cannot infer event type, please provide event_type explicitly" |
303 | | - ) |
304 | | - |
305 | | - # Check if it's a Union type (Union[A, B] or A | B) |
306 | | - origin = get_origin(type_hint) |
307 | | - if origin is Union or origin is types.UnionType: |
308 | | - event_types: list[type[TEvent]] = [] |
309 | | - for arg in get_args(type_hint): |
310 | | - if arg is type(None): |
311 | | - raise ValueError("None is not a valid event type in union") |
312 | | - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): |
313 | | - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") |
314 | | - event_types.append(cast(type[TEvent], arg)) |
315 | | - return event_types |
316 | | - |
317 | | - # Handle single type |
318 | | - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): |
319 | | - return [cast(type[TEvent], type_hint)] |
320 | | - |
321 | | - raise ValueError( |
322 | | - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" |
323 | | - ) |
324 | | - |
325 | 252 | def add_hook(self, hook: HookProvider) -> None: |
326 | 253 | """Register all callbacks from a hook provider. |
327 | 254 |
|
|
0 commit comments