1313)
1414
1515import pydantic
16- from taskiq import AsyncBroker , AsyncTaskiqTask
16+ from taskiq import AsyncBroker , AsyncTaskiqTask , TaskiqResult
1717from taskiq .decor import AsyncTaskiqDecoratedTask
1818from taskiq .kicker import AsyncKicker
1919from typing_extensions import ParamSpec
@@ -52,20 +52,44 @@ class Pipeline(Generic[_FuncParams, _ReturnType]):
5252 but it's nice to have.
5353 """
5454
55+ @overload
56+ def __init__ (
57+ self : "Pipeline[[], _ReturnType]" ,
58+ broker : AsyncBroker ,
59+ task : Optional [Group [_ReturnType ]] = None ,
60+ ) -> None : ...
61+
62+ @overload
63+ def __init__ (
64+ self ,
65+ broker : AsyncBroker ,
66+ task : Optional [
67+ Union [
68+ AsyncKicker [_FuncParams , _ReturnType ],
69+ AsyncTaskiqDecoratedTask [_FuncParams , _ReturnType ],
70+ ]
71+ ] = None ,
72+ ) -> None : ...
73+
5574 def __init__ (
5675 self ,
5776 broker : AsyncBroker ,
5877 task : Optional [
5978 Union [
6079 AsyncKicker [_FuncParams , _ReturnType ],
6180 AsyncTaskiqDecoratedTask [_FuncParams , _ReturnType ],
81+ Group [_ReturnType ],
6282 ]
6383 ] = None ,
6484 ) -> None :
6585 self .broker = broker
6686 self .steps : "List[DumpedStep]" = []
67- if task :
68- self .call_next (task )
87+ if not task :
88+ return
89+ if isinstance (task , Group ):
90+ self .group (task )
91+ return
92+ self .call_next (task )
6993
7094 @overload
7195 def call_next (
@@ -379,6 +403,48 @@ def loadb(cls, broker: AsyncBroker, pipe_data: bytes) -> "Pipeline[Any, Any]":
379403 pipe .steps = DumpedSteps .model_validate (data ) # type: ignore[assignment]
380404 return pipe
381405
406+ async def _kick_sequential (
407+ self ,
408+ step : SequentialStep ,
409+ task_id : str ,
410+ * args : Any ,
411+ ** kwargs : Any ,
412+ ) -> AsyncTaskiqTask [_ReturnType ]:
413+ kicker = (
414+ AsyncKicker (
415+ step .task_name ,
416+ broker = self .broker ,
417+ labels = step .labels ,
418+ )
419+ .with_task_id (task_id )
420+ .with_labels (
421+ ** {CURRENT_STEP : 0 , PIPELINE_DATA : self .dumpb ()}, # type: ignore
422+ )
423+ )
424+ return await kicker .kiq (* args , ** kwargs )
425+
426+ async def _kick_group (
427+ self ,
428+ group : GroupStep ,
429+ task_id : str ,
430+ ) -> AsyncTaskiqTask [Any ]:
431+ await group .act (
432+ broker = self .broker ,
433+ task_id = task_id ,
434+ step_number = 0 ,
435+ parent_task_id = "" ,
436+ pipe_data = self .dumpb (),
437+ result = TaskiqResult (
438+ is_err = False ,
439+ return_value = None ,
440+ execution_time = 0.0 ,
441+ ),
442+ )
443+ return AsyncTaskiqTask (
444+ task_id = task_id ,
445+ result_backend = self .broker .result_backend ,
446+ )
447+
382448 async def kiq (
383449 self ,
384450 * args : _FuncParams .args ,
@@ -405,20 +471,18 @@ async def kiq(
405471 self ._update_task_ids ()
406472 step = self .steps [0 ]
407473 parsed_step = parse_step (step .step_type , step .step_data )
408- if not isinstance (parsed_step , SequentialStep ):
409- raise ValueError ("First step must be sequential." )
410- kicker = (
411- AsyncKicker (
412- parsed_step .task_name ,
413- broker = self .broker ,
414- labels = parsed_step .labels ,
474+ if isinstance (parsed_step , SequentialStep ):
475+ taskiq_task = await self ._kick_sequential (
476+ parsed_step ,
477+ step .task_id ,
478+ * args ,
479+ ** kwargs ,
415480 )
416- .with_task_id (step .task_id )
417- .with_labels (
418- ** {CURRENT_STEP : 0 , PIPELINE_DATA : self .dumpb ()}, # type: ignore
419- )
420- )
421- taskiq_task = await kicker .kiq (* args , ** kwargs )
481+ elif isinstance (parsed_step , GroupStep ):
482+ taskiq_task = await self ._kick_group (parsed_step , step .task_id )
483+ else :
484+ raise ValueError ("First step must be sequential or a group." )
485+
422486 taskiq_task .task_id = self .steps [- 1 ].task_id
423487 return taskiq_task
424488
0 commit comments