-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmiddleware.py
More file actions
136 lines (121 loc) · 4.39 KB
/
middleware.py
File metadata and controls
136 lines (121 loc) · 4.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from logging import getLogger
from typing import Any
import pydantic
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
from taskiq_pipelines.exceptions import AbortPipeline
from taskiq_pipelines.pipeliner import DumpedStep
from taskiq_pipelines.steps import parse_step
# Module-level logger, following the standard `logging.getLogger(__name__)` convention.
logger = getLogger(__name__)
class PipelineMiddleware(TaskiqMiddleware):
    """Middleware that advances a pipeline after each task finishes.

    Pipeline state travels in two message labels:

    * ``CURRENT_STEP`` -- index of the step that just ran.
    * ``PIPELINE_DATA`` -- serialized list of ``DumpedStep`` entries
      describing the whole pipeline.
    """

    async def post_save(  # noqa: PLR0911
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
    ) -> None:
        """
        Handle post-execute event.

        This is the heart of pipelines.
        Here we decide what to do next.

        If the message has pipeline labels,
        we compute and dispatch the next step.

        :param message: current message.
        :param result: result of the execution.
        """
        if result.is_err:
            # Failed executions are handled by on_error.
            return
        if CURRENT_STEP not in message.labels:
            # Not part of a pipeline.
            return
        current_step_num = int(message.labels[CURRENT_STEP])
        if PIPELINE_DATA not in message.labels:
            logger.warning("Pipeline data not found. Execution flow is broken.")
            return
        pipeline_data = message.labels[PIPELINE_DATA]
        # Decode with the broker's configured serializer,
        # then validate the structure with pydantic.
        parsed_data = self.broker.serializer.loadb(pipeline_data)
        try:
            steps_data = pydantic.TypeAdapter(list[DumpedStep]).validate_python(
                parsed_data,
            )
        except ValueError as err:
            logger.warning("Cannot parse pipeline_data: %s", err, exc_info=True)
            return
        if current_step_num + 1 >= len(steps_data):
            logger.debug("Pipeline is completed.")
            return
        next_step_data = steps_data[current_step_num + 1]
        try:
            next_step = parse_step(
                step_type=next_step_data.step_type,
                step_data=next_step_data.step_data,
            )
        except ValueError as exc:
            logger.warning("Cannot parse step data.")
            logger.debug("%s", exc, exc_info=True)
            return
        try:
            await next_step.act(
                broker=self.broker,
                step_number=current_step_num + 1,
                parent_task_id=message.task_id,
                task_id=next_step_data.task_id,
                pipe_data=pipeline_data,
                result=result,
            )
        except AbortPipeline as abort_exc:
            logger.warning(
                "Pipeline is aborted. Reason: %s",
                abort_exc,
                exc_info=True,
            )
            # The early-return above guarantees
            # current_step_num + 1 < len(steps_data), so there is always a
            # later "last" task whose result must be failed here.
            # Pass the caught exception through so the real abort reason
            # is stored instead of a generic message.
            await self.fail_pipeline(steps_data[-1].task_id, abort_exc)

    async def on_error(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
        exception: BaseException,
    ) -> None:
        """
        Handle on_error event.

        If a pipeline task fails mid-pipeline, mark the whole pipeline
        as failed by writing an error result for its last task.

        :param message: current message.
        :param result: execution result.
        :param exception: found exception.
        """
        if CURRENT_STEP not in message.labels:
            return
        current_step_num = int(message.labels[CURRENT_STEP])
        if PIPELINE_DATA not in message.labels:
            logger.warning("Pipeline data not found. Execution flow is broken.")
            return
        pipe_data = message.labels[PIPELINE_DATA]
        try:
            # Decode via the broker's serializer, exactly as post_save does.
            # Calling validate_json here would assume a JSON serializer and,
            # under any other serializer, silently skip failing the pipeline.
            steps = pydantic.TypeAdapter(list[DumpedStep]).validate_python(
                self.broker.serializer.loadb(pipe_data),
            )
        except ValueError:
            return
        if current_step_num == len(steps) - 1:
            # The failed task is the last step; its own error result
            # already represents the pipeline's outcome.
            return
        await self.fail_pipeline(steps[-1].task_id, result.error)

    async def fail_pipeline(
        self,
        last_task_id: str,
        abort: BaseException | None = None,
    ) -> None:
        """
        Abort the pipeline.

        This is done by setting an error result for
        the last task in the pipeline, so anything awaiting the
        pipeline's final result observes the failure.

        :param last_task_id: id of the last task.
        :param abort: caught earlier exception or default
        """
        await self.broker.result_backend.set_result(
            last_task_id,
            TaskiqResult(
                is_err=True,
                return_value=None,  # type: ignore
                error=abort or AbortPipeline(reason="Execution aborted."),
                execution_time=0,
                log="Error found while executing pipeline.",
            ),
        )