Skip to content

Commit b89b8bc

Browse files
committed
Added group with args.
1 parent d979c43 commit b89b8bc

File tree

3 files changed

+160
-8
lines changed

3 files changed

+160
-8
lines changed

taskiq_pipelines/pipeliner.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from taskiq_pipelines.constants import CURRENT_STEP, EMPTY_PARAM_NAME, PIPELINE_DATA
2222
from taskiq_pipelines.steps import FilterStep, MapperStep, SequentialStep, parse_step
2323
from taskiq_pipelines.steps.group import GroupStep
24-
from taskiq_pipelines.task_group import Group
24+
from taskiq_pipelines.task_group import Group, GroupWithArgs
2525

2626
_ReturnType = TypeVar("_ReturnType")
2727
_FuncParams = ParamSpec("_FuncParams")
@@ -54,9 +54,9 @@ class Pipeline(Generic[_FuncParams, _ReturnType]):
5454

5555
@overload
5656
def __init__(
57-
self: "Pipeline[[], _ReturnType]",
57+
self: "Pipeline[[], Any]",
5858
broker: AsyncBroker,
59-
task: Optional[Group[Any, _ReturnType]] = None,
59+
task: None = None,
6060
) -> None: ...
6161

6262
@overload
@@ -351,10 +351,22 @@ def filter(
351351
)
352352
return self
353353

354+
@overload
355+
def group(
356+
self: "Pipeline[_FuncParams, _ReturnType]",
357+
group: GroupWithArgs[Any, _T2],
358+
) -> "Pipeline[_FuncParams, _T2]": ...
359+
360+
@overload
354361
def group(
355362
self: "Pipeline[_FuncParams, _ReturnType]",
356363
group: Group[Any, _T2],
357-
) -> "Pipeline[_FuncParams, _T2]":
364+
) -> "Pipeline[_FuncParams, _T2]": ...
365+
366+
def group(
367+
self: "Pipeline[_FuncParams, _ReturnType]",
368+
group: Group[Any, Any] | GroupWithArgs[Any, Any],
369+
) -> "Pipeline[_FuncParams, Any]":
358370
"""
359371
Add group task execution step.
360372

taskiq_pipelines/steps/group.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class GroupStepItem(pydantic.BaseModel):
4141
labels: Dict[str, Any]
4242
labels_types: Optional[Dict[str, int]] = None
4343
args: List[Any]
44+
param_name: Optional[str]
4445
kwargs: Dict[str, Any]
4546

4647
def from_message(self, message: TaskiqMessage) -> None:
@@ -54,19 +55,31 @@ def from_message(self, message: TaskiqMessage) -> None:
5455
self.args = message.args
5556
self.kwargs = message.kwargs
5657

57-
def to_message(self, task_id: str) -> TaskiqMessage:
58+
def to_message(
59+
self,
60+
task_id: str,
61+
result: Optional[TaskiqResult[Any]] = None,
62+
) -> TaskiqMessage:
5863
"""
5964
Convert this item to message.
6065
6166
:return: message
6267
"""
68+
args = self.args
69+
kwargs = self.kwargs
70+
if result:
71+
if self.param_name:
72+
kwargs[self.param_name] = result.return_value
73+
else:
74+
args = [result.return_value, *args]
75+
6376
return TaskiqMessage(
6477
task_id=task_id,
6578
task_name=self.task_name,
6679
labels=self.labels,
6780
labels_types=self.labels_types,
68-
args=self.args,
69-
kwargs=self.kwargs,
81+
args=args,
82+
kwargs=kwargs,
7083
)
7184

7285

@@ -76,6 +89,7 @@ class GroupStep(pydantic.BaseModel, AbstractStep, step_name="group"):
7689
tasks: list[GroupStepItem]
7790
skip_errors: bool
7891
check_interval: float
92+
pass_args: bool = False
7993

8094
async def act(
8195
self,
@@ -99,7 +113,13 @@ async def act(
99113
for task in self.tasks:
100114
subtask_id = broker.id_generator()
101115
ids.append(subtask_id)
102-
await broker.kick(broker.formatter.dumps(task.to_message(subtask_id)))
116+
117+
if self.pass_args:
118+
message = task.to_message(subtask_id, result)
119+
else:
120+
message = task.to_message(subtask_id, None)
121+
122+
await broker.kick(broker.formatter.dumps(message))
103123

104124
return await (
105125
wait_group_tasks.kicker()

taskiq_pipelines/task_group.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def add(
113113
labels_types=message.labels_types,
114114
args=message.args,
115115
kwargs=message.kwargs,
116+
param_name=None,
116117
),
117118
)
118119
return self
@@ -123,4 +124,123 @@ def to_step(self) -> GroupStep:
123124
tasks=list(self.tasks),
124125
skip_errors=self.skip_errors,
125126
check_interval=self.check_interval,
127+
pass_args=False,
128+
)
129+
130+
131+
class GroupWithArgs(Generic[_S, _T]):
132+
"""
133+
Group of tasks with args.
134+
135+
This class gathers multiple tasks together.
136+
They will run in parallel. The only difference
137+
with Group is that it passes the result
138+
of the previous task as the first argument
139+
to all of the tasks in the group.
140+
141+
:param skip_errors: If True, errors in one task will not affect others.
142+
"""
143+
144+
# If skip_errors is set to True,
145+
# We update first generic type to Literal[True].
146+
@overload
147+
def __init__(
148+
self: "GroupWithArgs[Literal[True], Tuple[()]]",
149+
skip_errors: Literal[True],
150+
check_interval: float = 0.1,
151+
) -> None: ...
152+
153+
@overload
154+
def __init__(
155+
self: "GroupWithArgs[Literal[False], Tuple[()]]",
156+
skip_errors: bool = False,
157+
check_interval: float = 0.1,
158+
) -> None: ...
159+
160+
def __init__(
161+
self: "GroupWithArgs[Any, Tuple[()]]",
162+
skip_errors: bool = False,
163+
check_interval: float = 0.1,
164+
) -> None:
165+
self.tasks: Tuple[GroupStepItem, ...] = ()
166+
self.skip_errors = skip_errors
167+
self.check_interval = check_interval
168+
169+
@overload
170+
def add(
171+
self: "GroupWithArgs[Literal[True], Tuple[Unpack[_Tups]]]",
172+
task: Union[
173+
AsyncKicker[_Params, Coroutine[Any, Any, _TVal]],
174+
AsyncKicker[_Params, "CoroutineType[Any, Any, _TVal]"],
175+
AsyncTaskiqDecoratedTask[_Params, Coroutine[Any, Any, _TVal]],
176+
AsyncTaskiqDecoratedTask[_Params, "CoroutineType[Any, Any, _TVal]"],
177+
],
178+
param_name: Optional[str] = None,
179+
**additional_kwargs: Any,
180+
) -> "GroupWithArgs[_S, Tuple[Unpack[_Tups], Optional[_TVal]]]": ...
181+
182+
@overload
183+
def add(
184+
self: "GroupWithArgs[Literal[False], Tuple[Unpack[_Tups]]]",
185+
task: Union[
186+
AsyncKicker[_Params, Coroutine[Any, Any, _TVal]],
187+
AsyncKicker[_Params, "CoroutineType[Any, Any, _TVal]"],
188+
AsyncTaskiqDecoratedTask[_Params, Coroutine[Any, Any, _TVal]],
189+
AsyncTaskiqDecoratedTask[_Params, "CoroutineType[Any, Any, _TVal]"],
190+
],
191+
param_name: Optional[str] = None,
192+
**additional_kwargs: Any,
193+
) -> "GroupWithArgs[_S, Tuple[Unpack[_Tups], _TVal]]": ...
194+
195+
@overload
196+
def add(
197+
self: "GroupWithArgs[Literal[True], Tuple[Unpack[_Tups]]]",
198+
task: Union[
199+
AsyncKicker[_Params, _TVal],
200+
AsyncTaskiqDecoratedTask[_Params, _TVal],
201+
],
202+
param_name: Optional[str] = None,
203+
**additional_kwargs: Any,
204+
) -> "GroupWithArgs[_S, Tuple[Unpack[_Tups], Optional[_TVal]]]": ...
205+
206+
@overload
207+
def add(
208+
self: "GroupWithArgs[Literal[False], Tuple[Unpack[_Tups]]]",
209+
task: Union[
210+
AsyncKicker[_Params, _TVal],
211+
AsyncTaskiqDecoratedTask[_Params, _TVal],
212+
],
213+
param_name: Optional[str] = None,
214+
**additional_kwargs: Any,
215+
) -> "GroupWithArgs[_S, Tuple[Unpack[_Tups], _TVal]]": ...
216+
217+
def add(
218+
self: "GroupWithArgs[Any, Any]",
219+
task: Union[AsyncKicker[_Params, Any], AsyncTaskiqDecoratedTask[_Params, Any]],
220+
param_name: Optional[str] = None,
221+
**additional_kwargs: Any,
222+
) -> "Any":
223+
"""Add task to a group."""
224+
kicker = task.kicker() if isinstance(task, AsyncTaskiqDecoratedTask) else task
225+
message = kicker._prepare_message(**additional_kwargs)
226+
self.tasks = (
227+
*self.tasks,
228+
GroupStepItem(
229+
task_name=message.task_name,
230+
labels=message.labels,
231+
labels_types=message.labels_types,
232+
args=message.args,
233+
kwargs=message.kwargs,
234+
param_name=param_name,
235+
),
236+
)
237+
return self
238+
239+
def to_step(self) -> GroupStep:
240+
"""Convert group definition to a step."""
241+
return GroupStep(
242+
tasks=list(self.tasks),
243+
skip_errors=self.skip_errors,
244+
check_interval=self.check_interval,
245+
pass_args=True,
126246
)

0 commit comments

Comments
 (0)