Skip to content

Commit 18ddcab

Browse files
committed
Fixed types for skip_errors.
1 parent c204f49 commit 18ddcab

File tree

3 files changed

+60
-13
lines changed

3 files changed

+60
-13
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,15 @@ def to_string(val: Any) -> str:
200200

201201
async def main():
202202
pipe = (
203-
Pipeline(broker).group(
204-
Group()
203+
Pipeline(broker)
204+
.group(
205+
Group(
206+
# Aborts pipeline
207+
# if any of tasks fails
208+
skip_errors=False,
209+
# How often to check for completion.
210+
check_interval=0.1,
211+
)
205212
# Here we start task that adds 1 to 1
206213
.add(add_one, 1)
207214
# Here's a task that multiplies 2 by 2

taskiq_pipelines/pipeliner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class Pipeline(Generic[_FuncParams, _ReturnType]):
5656
def __init__(
5757
self: "Pipeline[[], _ReturnType]",
5858
broker: AsyncBroker,
59-
task: Optional[Group[_ReturnType]] = None,
59+
task: Optional[Group[Any, _ReturnType]] = None,
6060
) -> None: ...
6161

6262
@overload
@@ -78,7 +78,7 @@ def __init__(
7878
Union[
7979
AsyncKicker[_FuncParams, _ReturnType],
8080
AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],
81-
Group[_ReturnType],
81+
Group[Any, _ReturnType],
8282
]
8383
] = None,
8484
) -> None:
@@ -353,7 +353,7 @@ def filter(
353353

354354
def group(
355355
self: "Pipeline[_FuncParams, _ReturnType]",
356-
group: Group[_T2],
356+
group: Group[Any, _T2],
357357
) -> "Pipeline[_FuncParams, _T2]":
358358
"""
359359
Add group task execution step.

taskiq_pipelines/task_group.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from types import CoroutineType
2-
from typing import Any, Coroutine, Generic, Tuple, Union, overload
2+
from typing import Any, Coroutine, Generic, Literal, Optional, Tuple, Union, overload
33

44
from taskiq import AsyncTaskiqDecoratedTask
55
from taskiq.kicker import AsyncKicker
@@ -9,11 +9,13 @@
99

1010
_Tups = TypeVarTuple("_Tups")
1111
_T = TypeVar("_T")
12+
# Whether skip_errors is set to True or False
13+
_S = TypeVar("_S", bound=bool)
1214
_TVal = TypeVar("_TVal")
1315
_Params = ParamSpec("_Params")
1416

1517

16-
class Group(Generic[_T]):
18+
class Group(Generic[_S, _T]):
1719
"""
1820
Group of tasks.
1921
@@ -23,8 +25,22 @@ class Group(Generic[_T]):
2325
:param skip_errors: If True, errors in one task will not affect others.
2426
"""
2527

28+
@overload
29+
def __init__(
30+
self: "Group[Literal[True], Tuple[()]]",
31+
skip_errors: Literal[True],
32+
check_interval: float = 0.1,
33+
) -> None: ...
34+
35+
@overload
2636
def __init__(
27-
self: "Group[Tuple[()]]",
37+
self: "Group[Literal[False], Tuple[()]]",
38+
skip_errors: bool = False,
39+
check_interval: float = 0.1,
40+
) -> None: ...
41+
42+
def __init__(
43+
self: "Group[Any, Tuple[()]]",
2844
skip_errors: bool = False,
2945
check_interval: float = 0.1,
3046
) -> None:
@@ -34,7 +50,7 @@ def __init__(
3450

3551
@overload
3652
def add(
37-
self: "Group[Tuple[Unpack[_Tups]]]",
53+
self: "Group[Literal[True], Tuple[Unpack[_Tups]]]",
3854
task: Union[
3955
AsyncKicker[_Params, Coroutine[Any, Any, _TVal]],
4056
AsyncKicker[_Params, "CoroutineType[Any, Any, _TVal]"],
@@ -43,21 +59,45 @@ def add(
4359
],
4460
*args: _Params.args,
4561
**kwargs: _Params.kwargs,
46-
) -> "Group[Tuple[Unpack[_Tups], _TVal]]": ...
62+
) -> "Group[_S, Tuple[Unpack[_Tups], Optional[_TVal]]]": ...
63+
64+
@overload
65+
def add(
66+
self: "Group[Literal[False], Tuple[Unpack[_Tups]]]",
67+
task: Union[
68+
AsyncKicker[_Params, Coroutine[Any, Any, _TVal]],
69+
AsyncKicker[_Params, "CoroutineType[Any, Any, _TVal]"],
70+
AsyncTaskiqDecoratedTask[_Params, Coroutine[Any, Any, _TVal]],
71+
AsyncTaskiqDecoratedTask[_Params, "CoroutineType[Any, Any, _TVal]"],
72+
],
73+
*args: _Params.args,
74+
**kwargs: _Params.kwargs,
75+
) -> "Group[_S, Tuple[Unpack[_Tups], _TVal]]": ...
76+
77+
@overload
78+
def add(
79+
self: "Group[Literal[True], Tuple[Unpack[_Tups]]]",
80+
task: Union[
81+
AsyncKicker[_Params, _TVal],
82+
AsyncTaskiqDecoratedTask[_Params, _TVal],
83+
],
84+
*args: _Params.args,
85+
**kwargs: _Params.kwargs,
86+
) -> "Group[_S, Tuple[Unpack[_Tups], Optional[_TVal]]]": ...
4787

4888
@overload
4989
def add(
50-
self: "Group[Tuple[Unpack[_Tups]]]",
90+
self: "Group[Literal[False], Tuple[Unpack[_Tups]]]",
5191
task: Union[
5292
AsyncKicker[_Params, _TVal],
5393
AsyncTaskiqDecoratedTask[_Params, _TVal],
5494
],
5595
*args: _Params.args,
5696
**kwargs: _Params.kwargs,
57-
) -> "Group[Tuple[Unpack[_Tups], _TVal]]": ...
97+
) -> "Group[_S, Tuple[Unpack[_Tups], _TVal]]": ...
5898

5999
def add(
60-
self: "Group[Any]",
100+
self: "Group[Any, Any]",
61101
task: Union[AsyncKicker[_Params, Any], AsyncTaskiqDecoratedTask[_Params, Any]],
62102
*args: _Params.args,
63103
**kwargs: _Params.kwargs,

0 commit comments

Comments
 (0)