Skip to content

Commit 907ffec

Browse files
committed
feat: add plugin interface
1 parent e80e390 commit 907ffec

10 files changed

Lines changed: 2212 additions & 204 deletions

File tree

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 209 additions & 169 deletions
Large diffs are not rendered by default.

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ class OperationSubType(Enum):
9696
CHAINED_INVOKE = "ChainedInvoke"
9797

9898

99+
class InvocationStatus(Enum):
100+
SUCCEEDED = "SUCCEEDED"
101+
FAILED = "FAILED"
102+
PENDING = "PENDING"
103+
104+
# Used internally only: the invocation failed and the backend will retry
105+
RETRY = "RETRY"
106+
107+
99108
@dataclass(frozen=True)
100109
class ExecutionDetails:
101110
input_payload: str | None = None
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import datetime
2+
import logging
3+
from abc import ABC
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING
6+
7+
from aws_durable_execution_sdk_python.lambda_service import (
8+
OperationType,
9+
OperationStatus,
10+
OperationAction,
11+
OperationSubType,
12+
ErrorObject,
13+
InvocationStatus,
14+
Operation,
15+
)
16+
from aws_durable_execution_sdk_python.types import LambdaContext
17+
18+
if TYPE_CHECKING:
19+
from aws_durable_execution_sdk_python.execution import (
20+
DurableExecutionInvocationOutput,
21+
)
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
@dataclass
27+
class OperationStartInfo:
28+
operation_id: str
29+
operation_type: OperationType
30+
sub_type: OperationSubType | None = None
31+
name: str | None = None
32+
parent_id: str | None = None
33+
start_timestamp: datetime.datetime | None = None
34+
35+
36+
@dataclass
37+
class OperationEndInfo(OperationStartInfo):
38+
status: OperationStatus = OperationStatus.SUCCEEDED
39+
end_timestamp: datetime.datetime | None = None
40+
attempt: int | None = None
41+
error: ErrorObject | None = None
42+
43+
44+
@dataclass
45+
class AttemptStartInfo(OperationStartInfo):
46+
attempt: int = 1
47+
48+
49+
@dataclass
50+
class AttemptEndInfo(AttemptStartInfo):
51+
succeeded: bool | None = None
52+
end_timestamp: datetime.datetime | None = None
53+
error: ErrorObject | None = None
54+
next_attempt_delay_seconds: int | None = None
55+
56+
57+
@dataclass
58+
class InvocationStartInfo:
59+
request_id: str | None
60+
execution_arn: str | None
61+
start_timestamp: datetime.datetime | None
62+
63+
64+
@dataclass
65+
class InvocationEndInfo(InvocationStartInfo):
66+
status: InvocationStatus = InvocationStatus.SUCCEEDED
67+
end_timestamp: datetime.datetime | None = None
68+
error: ErrorObject | None = None
69+
70+
71+
@dataclass
72+
class ExecutionStartInfo(InvocationStartInfo):
73+
pass
74+
75+
76+
@dataclass
77+
class ExecutionEndInfo(ExecutionStartInfo):
78+
status: InvocationStatus = InvocationStatus.SUCCEEDED
79+
end_timestamp: datetime.datetime | None = None
80+
error: ErrorObject | None = None
81+
82+
83+
class DurableExecutionPlugin(ABC):
84+
"""Base class for plugins. Override only the methods you need."""
85+
86+
def on_execution_start(self, info: ExecutionStartInfo) -> None:
87+
pass
88+
89+
def on_execution_end(self, info: ExecutionEndInfo) -> None:
90+
pass
91+
92+
def on_invocation_start(self, info: InvocationStartInfo) -> None:
93+
pass
94+
95+
def on_invocation_end(self, info: InvocationEndInfo) -> None:
96+
pass
97+
98+
def on_operation_start(self, info: OperationStartInfo) -> None:
99+
pass
100+
101+
def on_operation_end(self, info: OperationEndInfo) -> None:
102+
pass
103+
104+
def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
105+
pass
106+
107+
def on_operation_attempt_end(self, info: AttemptEndInfo) -> None:
108+
pass
109+
110+
# Todo: further discussions required to finalize the following interface
111+
# def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass
112+
113+
114+
class PluginExecutor:
115+
def __init__(self, plugins: list[DurableExecutionPlugin] | None):
116+
self.plugins = plugins or []
117+
118+
def execute_plugins(self, info):
119+
for plugin in self.plugins:
120+
try:
121+
match info:
122+
case ExecutionEndInfo():
123+
plugin.on_execution_end(info)
124+
case InvocationEndInfo():
125+
plugin.on_invocation_end(info)
126+
case ExecutionStartInfo():
127+
plugin.on_execution_start(info)
128+
case InvocationStartInfo():
129+
plugin.on_invocation_start(info)
130+
case AttemptEndInfo():
131+
plugin.on_operation_attempt_end(info)
132+
case OperationEndInfo():
133+
plugin.on_operation_end(info)
134+
case AttemptStartInfo():
135+
plugin.on_operation_attempt_start(info)
136+
case OperationStartInfo():
137+
plugin.on_operation_start(info)
138+
case _:
139+
raise ValueError(f"Unknown info type: {type(info)}")
140+
except Exception:
141+
# log and ignore the exception
142+
logger.exception(
143+
"Plugin %s exception ignored", plugin.__class__.__name__
144+
)
145+
146+
def on_invocation_start(
147+
self,
148+
durable_execution_arn: str,
149+
context: LambdaContext | None,
150+
execution_operation: Operation | None,
151+
is_replaying: bool,
152+
) -> None:
153+
aws_request_id = context.aws_request_id if context else None
154+
start_timestamp = (
155+
execution_operation.start_timestamp if execution_operation else None
156+
)
157+
158+
if not is_replaying:
159+
self.execute_plugins(
160+
ExecutionStartInfo(
161+
request_id=aws_request_id,
162+
execution_arn=durable_execution_arn,
163+
start_timestamp=start_timestamp,
164+
)
165+
)
166+
167+
self.execute_plugins(
168+
InvocationStartInfo(
169+
request_id=aws_request_id,
170+
execution_arn=durable_execution_arn,
171+
start_timestamp=start_timestamp,
172+
)
173+
)
174+
175+
def on_invocation_end(
176+
self,
177+
durable_execution_arn: str | None,
178+
context: LambdaContext,
179+
execution_operation: Operation | None,
180+
output: "DurableExecutionInvocationOutput",
181+
) -> None:
182+
start_timestamp = (
183+
execution_operation.start_timestamp if execution_operation else None
184+
)
185+
# the actual end timestamp may be unknown because it's not checkpointed yet
186+
end_timestamp: datetime.datetime = (
187+
execution_operation.end_timestamp if execution_operation else None
188+
) or datetime.datetime.now()
189+
request_id = context.aws_request_id if context else None
190+
191+
self.execute_plugins(
192+
InvocationEndInfo(
193+
request_id=request_id,
194+
execution_arn=durable_execution_arn,
195+
start_timestamp=start_timestamp,
196+
status=output.status,
197+
end_timestamp=end_timestamp,
198+
error=output.error,
199+
)
200+
)
201+
202+
if output.status in [InvocationStatus.SUCCEEDED, InvocationStatus.FAILED]:
203+
self.execute_plugins(
204+
ExecutionEndInfo(
205+
request_id=request_id,
206+
execution_arn=durable_execution_arn,
207+
start_timestamp=start_timestamp,
208+
status=output.status,
209+
end_timestamp=end_timestamp,
210+
error=output.error,
211+
)
212+
)
213+
214+
def on_operation_action(self, update):
215+
"""Execute any registered plugins for a given operation before it is updated.
216+
217+
Args:
218+
update: the operation update that is pending checkpoint
219+
"""
220+
if update.action is OperationAction.START:
221+
self.execute_plugins(
222+
OperationStartInfo(
223+
operation_id=update.operation_id,
224+
operation_type=update.operation_type,
225+
sub_type=update.sub_type,
226+
name=update.name,
227+
parent_id=update.parent_id,
228+
start_timestamp=datetime.datetime.now(),
229+
)
230+
)
231+
232+
def on_operation_update(self, operation):
233+
"""Execute any registered plugins for a given operation after it is updated.
234+
235+
Updates such as STARTED might be omitted because START and completion action (e.g. SUCCEED/FAIL) may be
236+
checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED).
237+
238+
Args:
239+
operation: the operation is just checkpointed
240+
"""
241+
params = dict(
242+
operation_id=operation.operation_id,
243+
operation_type=operation.operation_type,
244+
sub_type=operation.sub_type,
245+
name=operation.name,
246+
parent_id=operation.parent_id,
247+
start_timestamp=operation.start_timestamp,
248+
)
249+
# todo: Python SDK doesn't submit a START update when retrying
250+
if operation.step_details and (
251+
self._is_terminal_status(operation.status)
252+
# PENDING in addition to terminal status
253+
or operation.status is OperationStatus.PENDING
254+
):
255+
self.execute_plugins(AttemptStartInfo(**params))
256+
self.execute_plugins(
257+
AttemptEndInfo(
258+
**params,
259+
end_timestamp=operation.end_timestamp,
260+
attempt=operation.step_details.attempt,
261+
succeeded=operation.status is OperationStatus.SUCCEEDED,
262+
error=operation.step_details.error,
263+
)
264+
)
265+
266+
if self._is_terminal_status(operation.status):
267+
attempt = operation.step_details.attempt if operation.step_details else None
268+
self.execute_plugins(OperationStartInfo(**params))
269+
self.execute_plugins(
270+
OperationEndInfo(
271+
**params,
272+
end_timestamp=operation.end_timestamp,
273+
status=operation.status,
274+
error=self._extract_error(operation),
275+
attempt=attempt,
276+
)
277+
)
278+
279+
@staticmethod
280+
def _extract_error(operation: Operation):
281+
if operation.step_details and operation.step_details.error:
282+
return operation.step_details.error
283+
if operation.callback_details and operation.callback_details.error:
284+
return operation.callback_details.error
285+
if operation.chained_invoke_details and operation.chained_invoke_details.error:
286+
return operation.chained_invoke_details.error
287+
if operation.context_details and operation.context_details.error:
288+
return operation.context_details.error
289+
return None
290+
291+
@staticmethod
292+
def _is_terminal_status(status):
293+
return status in [
294+
OperationStatus.SUCCEEDED,
295+
OperationStatus.FAILED,
296+
OperationStatus.TIMED_OUT,
297+
OperationStatus.CANCELLED,
298+
OperationStatus.STOPPED,
299+
]

src/aws_durable_execution_sdk_python/state.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import queue
88
import threading
99
import time
10+
from concurrent.futures import Executor
11+
from contextlib import contextmanager
1012
from dataclasses import dataclass
1113
from enum import Enum
1214
from threading import Lock
@@ -30,6 +32,9 @@
3032
OperationUpdate,
3133
StateOutput,
3234
)
35+
from aws_durable_execution_sdk_python.plugin import (
36+
PluginExecutor,
37+
)
3338
from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock
3439

3540
if TYPE_CHECKING:
@@ -229,13 +234,15 @@ def __init__(
229234
initial_checkpoint_token: str,
230235
operations: MutableMapping[str, Operation],
231236
service_client: DurableServiceClient,
237+
plugin_executor: PluginExecutor | None = None,
232238
batcher_config: CheckpointBatcherConfig | None = None,
233239
replay_status: ReplayStatus = ReplayStatus.NEW,
234240
):
235241
self.durable_execution_arn: str = durable_execution_arn
236242
self._current_checkpoint_token: str = initial_checkpoint_token
237243
self.operations: MutableMapping[str, Operation] = operations
238244
self._service_client: DurableServiceClient = service_client
245+
self._plugin_executor: PluginExecutor = plugin_executor or PluginExecutor(None)
239246
self._ordered_checkpoint_lock: OrderedLock = OrderedLock()
240247
self._operations_lock: Lock = Lock()
241248

@@ -267,7 +274,7 @@ def fetch_paginated_operations(
267274
initial_operations: list[Operation],
268275
checkpoint_token: str,
269276
next_marker: str | None,
270-
) -> None:
277+
) -> list[Operation]:
271278
"""Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe.
272279
273280
The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety.
@@ -276,6 +283,8 @@ def fetch_paginated_operations(
276283
initial_operations: initial operations to be added to ExecutionState
277284
checkpoint_token: checkpoint token used to call Durable Functions API.
278285
next_marker: a marker indicates that there are paginated operations.
286+
Returns:
287+
List of all operations fetched from the Durable Functions API
279288
280289
Raises:
281290
GetExecutionStateError: If the API call fails. The error is logged
@@ -308,6 +317,7 @@ def fetch_paginated_operations(
308317
self.operations.update(
309318
{op.operation_id: op for op in all_operations}
310319
)
320+
return all_operations
311321

312322
def get_input_payload(self) -> str | None:
313323
# It is possible that backend will not provide an execution operation
@@ -682,12 +692,18 @@ def checkpoint_batches_forever(self) -> None:
682692
current_checkpoint_token = output.checkpoint_token
683693

684694
# Fetch new operations from the API before unblocking sync waiters
685-
self.fetch_paginated_operations(
695+
updated_operations = self.fetch_paginated_operations(
686696
output.new_execution_state.operations,
687697
output.checkpoint_token,
688698
output.new_execution_state.next_marker,
689699
)
690700

701+
for update in updates:
702+
self._plugin_executor.on_operation_action(update)
703+
704+
for operation in updated_operations:
705+
self._plugin_executor.on_operation_update(operation)
706+
691707
# Signal completion for any synchronous operations
692708
for queued_op in batch:
693709
if queued_op.completion_event is not None:

0 commit comments

Comments
 (0)