Skip to content

Commit 4773d57

Browse files
authored
fix: improve component.output_types decorator type hinting to support run_async methods (#9102)
* improve output_types type hinting * better name * docstrings
1 parent 6db8f0a commit 4773d57

2 files changed

Lines changed: 16 additions & 7 deletions

File tree

haystack/core/component/component.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@
7070
"""
7171

7272
import inspect
73-
from collections.abc import Callable
73+
from collections.abc import Callable, Coroutine
7474
from contextlib import contextmanager
7575
from contextvars import ContextVar
7676
from copy import deepcopy
7777
from dataclasses import dataclass
7878
from types import new_class
79-
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, runtime_checkable
79+
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, Union, runtime_checkable
8080

8181
from typing_extensions import ParamSpec
8282

@@ -88,8 +88,10 @@
8888

8989
logger = logging.getLogger(__name__)
9090

91-
P = ParamSpec("P")
92-
R = TypeVar("R", bound=Dict[str, Any])
91+
RunParamsT = ParamSpec("RunParamsT")
92+
SyncRunReturnT = TypeVar("SyncRunReturnT", bound=Dict[str, Any])
93+
AsyncRunReturnT = TypeVar("AsyncRunReturnT", bound=Coroutine[Any, Any, Dict[str, Any]])
94+
RunReturnT = Union[SyncRunReturnT, AsyncRunReturnT]
9395

9496

9597
@dataclass
@@ -447,12 +449,13 @@ def run(self, value: int):
447449
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
448450
)
449451

450-
def output_types(self, **types: Any) -> Callable[[Callable[P, R]], Callable[P, R]]:
452+
def output_types(
453+
self, **types: Any
454+
) -> Callable[[Callable[RunParamsT, RunReturnT]], Callable[RunParamsT, RunReturnT]]:
451455
"""
452456
Decorator factory that specifies the output types of a component.
453457
454458
Use as:
455-
456459
```python
457460
@component
458461
class MyComponent:
@@ -462,7 +465,7 @@ def run(self, value: int):
462465
```
463466
"""
464467

465-
def output_types_decorator(run_method: Callable[P, R]) -> Callable[P, R]:
468+
def output_types_decorator(run_method: Callable[RunParamsT, RunReturnT]) -> Callable[RunParamsT, RunReturnT]:
466469
"""
467470
Decorator that sets the output types of the decorated method.
468471
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
fixes:
3+
- |
4+
Improved type hinting for the `component.output_types` decorator. The type hinting for the decorator was originally
5+
introduced to avoid overshadowing the type hinting of the `run` method and allow proper static type checking.
6+
This update extends support to asynchronous `run_async` methods.

0 commit comments

Comments
 (0)