Skip to content

Commit 3fe9f81

Browse files
committed
Initial groups support.
1 parent 5820bfd commit 3fe9f81

File tree

7 files changed

+114
-38
lines changed

7 files changed

+114
-38
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,14 @@ If called tasks returned `True` for some element, this element will be added in
162162

163163
After the execution you'll get a list with filtered results.
164164
You can add filters by calling `.filter` method of the pipeline.
165+
166+
167+
### Group steps
168+
169+
This step groups together multiple tasks and sends them after the previous steps.
170+
171+
To create a group you need to use `Group` class from `taskiq_pipelienes` like this:
172+
173+
```
174+
175+
```

taskiq_pipelines/abc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from typing import Any, Dict, Type
33

4-
from taskiq import AsyncBroker, TaskiqResult
4+
from taskiq import AsyncBroker, AsyncTaskiqTask, TaskiqResult
55
from typing_extensions import ClassVar
66

77

@@ -26,9 +26,9 @@ async def act(
2626
step_number: int,
2727
parent_task_id: str,
2828
task_id: str,
29-
pipe_data: str,
29+
pipe_data: bytes,
3030
result: "TaskiqResult[Any]",
31-
) -> None:
31+
) -> AsyncTaskiqTask[Any]:
3232
"""
3333
Perform pipeline action.
3434

taskiq_pipelines/pipeliner.py

Lines changed: 80 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
import pydantic
16-
from taskiq import AsyncBroker, AsyncTaskiqTask
16+
from taskiq import AsyncBroker, AsyncTaskiqTask, TaskiqResult
1717
from taskiq.decor import AsyncTaskiqDecoratedTask
1818
from taskiq.kicker import AsyncKicker
1919
from typing_extensions import ParamSpec
@@ -52,20 +52,44 @@ class Pipeline(Generic[_FuncParams, _ReturnType]):
5252
but it's nice to have.
5353
"""
5454

55+
@overload
56+
def __init__(
57+
self: "Pipeline[[], _ReturnType]",
58+
broker: AsyncBroker,
59+
task: Optional[Group[_ReturnType]] = None,
60+
) -> None: ...
61+
62+
@overload
63+
def __init__(
64+
self,
65+
broker: AsyncBroker,
66+
task: Optional[
67+
Union[
68+
AsyncKicker[_FuncParams, _ReturnType],
69+
AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],
70+
]
71+
] = None,
72+
) -> None: ...
73+
5574
def __init__(
5675
self,
5776
broker: AsyncBroker,
5877
task: Optional[
5978
Union[
6079
AsyncKicker[_FuncParams, _ReturnType],
6180
AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],
81+
Group[_ReturnType],
6282
]
6383
] = None,
6484
) -> None:
6585
self.broker = broker
6686
self.steps: "List[DumpedStep]" = []
67-
if task:
68-
self.call_next(task)
87+
if not task:
88+
return
89+
if isinstance(task, Group):
90+
self.group(task)
91+
return
92+
self.call_next(task)
6993

7094
@overload
7195
def call_next(
@@ -379,6 +403,48 @@ def loadb(cls, broker: AsyncBroker, pipe_data: bytes) -> "Pipeline[Any, Any]":
379403
pipe.steps = DumpedSteps.model_validate(data) # type: ignore[assignment]
380404
return pipe
381405

406+
async def _kick_sequential(
407+
self,
408+
step: SequentialStep,
409+
task_id: str,
410+
*args: Any,
411+
**kwargs: Any,
412+
) -> AsyncTaskiqTask[_ReturnType]:
413+
kicker = (
414+
AsyncKicker(
415+
step.task_name,
416+
broker=self.broker,
417+
labels=step.labels,
418+
)
419+
.with_task_id(task_id)
420+
.with_labels(
421+
**{CURRENT_STEP: 0, PIPELINE_DATA: self.dumpb()}, # type: ignore
422+
)
423+
)
424+
return await kicker.kiq(*args, **kwargs)
425+
426+
async def _kick_group(
427+
self,
428+
group: GroupStep,
429+
task_id: str,
430+
) -> AsyncTaskiqTask[Any]:
431+
await group.act(
432+
broker=self.broker,
433+
task_id=task_id,
434+
step_number=0,
435+
parent_task_id="",
436+
pipe_data=self.dumpb(),
437+
result=TaskiqResult(
438+
is_err=False,
439+
return_value=None,
440+
execution_time=0.0,
441+
),
442+
)
443+
return AsyncTaskiqTask(
444+
task_id=task_id,
445+
result_backend=self.broker.result_backend,
446+
)
447+
382448
async def kiq(
383449
self,
384450
*args: _FuncParams.args,
@@ -405,20 +471,18 @@ async def kiq(
405471
self._update_task_ids()
406472
step = self.steps[0]
407473
parsed_step = parse_step(step.step_type, step.step_data)
408-
if not isinstance(parsed_step, SequentialStep):
409-
raise ValueError("First step must be sequential.")
410-
kicker = (
411-
AsyncKicker(
412-
parsed_step.task_name,
413-
broker=self.broker,
414-
labels=parsed_step.labels,
474+
if isinstance(parsed_step, SequentialStep):
475+
taskiq_task = await self._kick_sequential(
476+
parsed_step,
477+
step.task_id,
478+
*args,
479+
**kwargs,
415480
)
416-
.with_task_id(step.task_id)
417-
.with_labels(
418-
**{CURRENT_STEP: 0, PIPELINE_DATA: self.dumpb()}, # type: ignore
419-
)
420-
)
421-
taskiq_task = await kicker.kiq(*args, **kwargs)
481+
elif isinstance(parsed_step, GroupStep):
482+
taskiq_task = await self._kick_group(parsed_step, step.task_id)
483+
else:
484+
raise ValueError("First step must be sequential or a group.")
485+
422486
taskiq_task.task_id = self.steps[-1].task_id
423487
return taskiq_task
424488

taskiq_pipelines/steps/filter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Dict, Iterable, List, Optional, Union
33

44
import pydantic
5-
from taskiq import AsyncBroker, Context, TaskiqDepends, TaskiqResult
5+
from taskiq import AsyncBroker, AsyncTaskiqTask, Context, TaskiqDepends, TaskiqResult
66
from taskiq.brokers.shared_broker import async_shared_broker
77
from taskiq.decor import AsyncTaskiqDecoratedTask
88
from taskiq.kicker import AsyncKicker
@@ -88,9 +88,9 @@ async def act(
8888
step_number: int,
8989
parent_task_id: str,
9090
task_id: str,
91-
pipe_data: str,
91+
pipe_data: bytes,
9292
result: "TaskiqResult[Any]",
93-
) -> None:
93+
) -> AsyncTaskiqTask[Any]:
9494
"""
9595
Run filter action.
9696
@@ -121,7 +121,7 @@ async def act(
121121
else:
122122
task = await kicker.kiq(item, **self.additional_kwargs)
123123
sub_task_ids.append(task.task_id)
124-
await (
124+
return await (
125125
filter_tasks.kicker()
126126
.with_task_id(task_id)
127127
.with_broker(

taskiq_pipelines/steps/group.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pydantic
44
from taskiq import (
55
AsyncBroker,
6+
AsyncTaskiqTask,
67
Context,
78
TaskiqDepends,
89
TaskiqMessage,
@@ -82,9 +83,9 @@ async def act(
8283
step_number: int,
8384
parent_task_id: str,
8485
task_id: str,
85-
pipe_data: str,
86+
pipe_data: bytes,
8687
result: "TaskiqResult[Any]",
87-
) -> None:
88+
) -> AsyncTaskiqTask[Any]:
8889
"""
8990
Execute group action.
9091
@@ -100,7 +101,7 @@ async def act(
100101
ids.append(subtask_id)
101102
await broker.kick(broker.formatter.dumps(task.to_message(subtask_id)))
102103

103-
await (
104+
return await (
104105
wait_group_tasks.kicker()
105106
.with_broker(broker)
106107
.with_task_id(task_id)

taskiq_pipelines/steps/mapper.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from taskiq import (
77
AsyncBroker,
88
AsyncTaskiqDecoratedTask,
9+
AsyncTaskiqTask,
910
Context,
1011
TaskiqDepends,
1112
TaskiqResult,
@@ -60,8 +61,8 @@ async def wait_tasks( # noqa: C901
6061
results: List[Any] = []
6162
for task_id in ordered_ids:
6263
result = await context.broker.result_backend.get_result(task_id)
63-
logger.warning("Found error: %s", result.error)
6464
if result.is_err:
65+
logger.warning("Found error: %s", result.error)
6566
if skip_errors:
6667
continue
6768
if none_if_errors:
@@ -92,9 +93,9 @@ async def act(
9293
step_number: int,
9394
parent_task_id: str,
9495
task_id: str,
95-
pipe_data: str,
96+
pipe_data: bytes,
9697
result: "TaskiqResult[Any]",
97-
) -> None:
98+
) -> AsyncTaskiqTask[Any]:
9899
"""
99100
Runs mapping.
100101
@@ -132,7 +133,7 @@ async def act(
132133
task = await kicker.kiq(item, **self.additional_kwargs)
133134
sub_task_ids.append(task.task_id)
134135

135-
await (
136+
return await (
136137
wait_tasks.kicker()
137138
.with_task_id(task_id)
138139
.with_broker(

taskiq_pipelines/steps/sequential.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, Optional, Union
22

33
import pydantic
4-
from taskiq import AsyncBroker, AsyncTaskiqDecoratedTask, TaskiqResult
4+
from taskiq import AsyncBroker, AsyncTaskiqDecoratedTask, AsyncTaskiqTask, TaskiqResult
55
from taskiq.kicker import AsyncKicker
66

77
from taskiq_pipelines.abc import AbstractStep
@@ -30,9 +30,9 @@ async def act(
3030
step_number: int,
3131
parent_task_id: str,
3232
task_id: str,
33-
pipe_data: str,
33+
pipe_data: bytes,
3434
result: "TaskiqResult[Any]",
35-
) -> None:
35+
) -> AsyncTaskiqTask[Any]:
3636
"""
3737
Runs next task.
3838
@@ -64,11 +64,10 @@ async def act(
6464
)
6565
if isinstance(self.param_name, str):
6666
self.additional_kwargs[self.param_name] = result.return_value
67-
await kicker.kiq(**self.additional_kwargs)
68-
elif self.param_name == EMPTY_PARAM_NAME:
69-
await kicker.kiq(**self.additional_kwargs)
70-
else:
71-
await kicker.kiq(result.return_value, **self.additional_kwargs)
67+
return await kicker.kiq(**self.additional_kwargs)
68+
if self.param_name == EMPTY_PARAM_NAME:
69+
return await kicker.kiq(**self.additional_kwargs)
70+
return await kicker.kiq(result.return_value, **self.additional_kwargs)
7271

7372
@classmethod
7473
def from_task(

0 commit comments

Comments
 (0)