|
| 1 | +import datetime |
| 2 | +import logging |
| 3 | +from abc import ABC |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +from aws_durable_execution_sdk_python.lambda_service import ( |
| 8 | + OperationType, |
| 9 | + OperationStatus, |
| 10 | + OperationAction, |
| 11 | + OperationSubType, |
| 12 | + ErrorObject, |
| 13 | + InvocationStatus, |
| 14 | + Operation, |
| 15 | +) |
| 16 | +from aws_durable_execution_sdk_python.types import LambdaContext |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from aws_durable_execution_sdk_python.execution import ( |
| 20 | + DurableExecutionInvocationOutput, |
| 21 | + ) |
| 22 | + |
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
| 25 | + |
@dataclass
class OperationStartInfo:
    """Event payload handed to ``DurableExecutionPlugin.on_operation_start``.

    Also the common base for the other operation/attempt payloads, which add
    their own fields on top of the identifying fields declared here.
    """

    # Identifier and kind of the operation; both are required.
    operation_id: str
    operation_type: OperationType
    # Optional descriptive fields copied from the operation (update) as-is.
    sub_type: OperationSubType | None = None
    name: str | None = None
    parent_id: str | None = None
    # When the operation started, if known.
    start_timestamp: datetime.datetime | None = None
| 34 | + |
| 35 | + |
@dataclass
class OperationEndInfo(OperationStartInfo):
    """Event payload handed to ``DurableExecutionPlugin.on_operation_end``.

    Extends ``OperationStartInfo`` with the outcome of the operation.
    """

    # Final status of the operation; defaults to SUCCEEDED when not supplied.
    status: OperationStatus = OperationStatus.SUCCEEDED
    # When the operation ended, if known.
    end_timestamp: datetime.datetime | None = None
    # Attempt number taken from the step details; None when there are none.
    attempt: int | None = None
    # Error recorded for the operation, if any.
    error: ErrorObject | None = None
| 42 | + |
| 43 | + |
@dataclass
class AttemptStartInfo(OperationStartInfo):
    """Event payload handed to ``DurableExecutionPlugin.on_operation_attempt_start``.

    Extends ``OperationStartInfo`` with the number of the attempt that starts.
    """

    # Attempt number; defaults to 1 when the caller does not supply one.
    attempt: int = 1
| 47 | + |
| 48 | + |
@dataclass
class AttemptEndInfo(AttemptStartInfo):
    """Event payload handed to ``DurableExecutionPlugin.on_operation_attempt_end``.

    Extends ``AttemptStartInfo`` with the outcome of the attempt.
    """

    # True/False once the outcome is known; None when it was not supplied.
    succeeded: bool | None = None
    # When the attempt ended, if known.
    end_timestamp: datetime.datetime | None = None
    # Error recorded for the attempt, if any.
    error: ErrorObject | None = None
    # NOTE(review): nothing in this module populates this field yet;
    # presumably the delay before the next retry - confirm with the caller.
    next_attempt_delay_seconds: int | None = None
| 55 | + |
| 56 | + |
| 57 | +@dataclass |
| 58 | +class InvocationStartInfo: |
| 59 | + request_id: str | None |
| 60 | + execution_arn: str | None |
| 61 | + start_timestamp: datetime.datetime | None |
| 62 | + |
| 63 | + |
@dataclass
class InvocationEndInfo(InvocationStartInfo):
    """Event payload handed to ``DurableExecutionPlugin.on_invocation_end``.

    Extends ``InvocationStartInfo`` with the outcome of the invocation.
    """

    # Status of the invocation; defaults to SUCCEEDED when not supplied.
    status: InvocationStatus = InvocationStatus.SUCCEEDED
    # When the invocation ended, if known.
    end_timestamp: datetime.datetime | None = None
    # Error reported by the invocation output, if any.
    error: ErrorObject | None = None
| 69 | + |
| 70 | + |
@dataclass
class ExecutionStartInfo(InvocationStartInfo):
    """Event payload handed to ``DurableExecutionPlugin.on_execution_start``.

    Adds no fields of its own; the distinct type exists so that execution
    start events can be told apart from plain invocation start events.
    """
| 74 | + |
| 75 | + |
@dataclass
class ExecutionEndInfo(ExecutionStartInfo):
    """Event payload handed to ``DurableExecutionPlugin.on_execution_end``.

    Extends ``ExecutionStartInfo`` with the outcome of the execution. The
    added fields mirror those of ``InvocationEndInfo``.
    """

    # Final status of the execution; defaults to SUCCEEDED when not supplied.
    status: InvocationStatus = InvocationStatus.SUCCEEDED
    # When the execution ended, if known.
    end_timestamp: datetime.datetime | None = None
    # Error reported for the execution, if any.
    error: ErrorObject | None = None
| 81 | + |
| 82 | + |
class DurableExecutionPlugin(ABC):
    """Base class for plugins. Override only the methods you need.

    Every hook is a no-op that returns ``None``, so a subclass only has to
    implement the events it is interested in. No method is abstract.
    """

    def on_execution_start(self, info: ExecutionStartInfo) -> None:
        """Handle an ``ExecutionStartInfo`` event; does nothing by default."""

    def on_execution_end(self, info: ExecutionEndInfo) -> None:
        """Handle an ``ExecutionEndInfo`` event; does nothing by default."""

    def on_invocation_start(self, info: InvocationStartInfo) -> None:
        """Handle an ``InvocationStartInfo`` event; does nothing by default."""

    def on_invocation_end(self, info: InvocationEndInfo) -> None:
        """Handle an ``InvocationEndInfo`` event; does nothing by default."""

    def on_operation_start(self, info: OperationStartInfo) -> None:
        """Handle an ``OperationStartInfo`` event; does nothing by default."""

    def on_operation_end(self, info: OperationEndInfo) -> None:
        """Handle an ``OperationEndInfo`` event; does nothing by default."""

    def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
        """Handle an ``AttemptStartInfo`` event; does nothing by default."""

    def on_operation_attempt_end(self, info: AttemptEndInfo) -> None:
        """Handle an ``AttemptEndInfo`` event; does nothing by default."""

    # TODO: further discussions required to finalize the following interface
    # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass
| 112 | + |
| 113 | + |
| 114 | +class PluginExecutor: |
| 115 | + def __init__(self, plugins: list[DurableExecutionPlugin] | None): |
| 116 | + self.plugins = plugins or [] |
| 117 | + |
| 118 | + def execute_plugins(self, info): |
| 119 | + for plugin in self.plugins: |
| 120 | + try: |
| 121 | + match info: |
| 122 | + case ExecutionEndInfo(): |
| 123 | + plugin.on_execution_end(info) |
| 124 | + case InvocationEndInfo(): |
| 125 | + plugin.on_invocation_end(info) |
| 126 | + case ExecutionStartInfo(): |
| 127 | + plugin.on_execution_start(info) |
| 128 | + case InvocationStartInfo(): |
| 129 | + plugin.on_invocation_start(info) |
| 130 | + case AttemptEndInfo(): |
| 131 | + plugin.on_operation_attempt_end(info) |
| 132 | + case OperationEndInfo(): |
| 133 | + plugin.on_operation_end(info) |
| 134 | + case AttemptStartInfo(): |
| 135 | + plugin.on_operation_attempt_start(info) |
| 136 | + case OperationStartInfo(): |
| 137 | + plugin.on_operation_start(info) |
| 138 | + case _: |
| 139 | + raise ValueError(f"Unknown info type: {type(info)}") |
| 140 | + except Exception: |
| 141 | + # log and ignore the exception |
| 142 | + logger.exception( |
| 143 | + "Plugin %s exception ignored", plugin.__class__.__name__ |
| 144 | + ) |
| 145 | + |
| 146 | + def on_invocation_start( |
| 147 | + self, |
| 148 | + durable_execution_arn: str, |
| 149 | + context: LambdaContext | None, |
| 150 | + execution_operation: Operation | None, |
| 151 | + is_replaying: bool, |
| 152 | + ) -> None: |
| 153 | + aws_request_id = context.aws_request_id if context else None |
| 154 | + start_timestamp = ( |
| 155 | + execution_operation.start_timestamp if execution_operation else None |
| 156 | + ) |
| 157 | + |
| 158 | + if not is_replaying: |
| 159 | + self.execute_plugins( |
| 160 | + ExecutionStartInfo( |
| 161 | + request_id=aws_request_id, |
| 162 | + execution_arn=durable_execution_arn, |
| 163 | + start_timestamp=start_timestamp, |
| 164 | + ) |
| 165 | + ) |
| 166 | + |
| 167 | + self.execute_plugins( |
| 168 | + InvocationStartInfo( |
| 169 | + request_id=aws_request_id, |
| 170 | + execution_arn=durable_execution_arn, |
| 171 | + start_timestamp=start_timestamp, |
| 172 | + ) |
| 173 | + ) |
| 174 | + |
| 175 | + def on_invocation_end( |
| 176 | + self, |
| 177 | + durable_execution_arn: str | None, |
| 178 | + context: LambdaContext, |
| 179 | + execution_operation: Operation | None, |
| 180 | + output: "DurableExecutionInvocationOutput", |
| 181 | + ) -> None: |
| 182 | + start_timestamp = ( |
| 183 | + execution_operation.start_timestamp if execution_operation else None |
| 184 | + ) |
| 185 | + # the actual end timestamp may be unknown because it's not checkpointed yet |
| 186 | + end_timestamp: datetime.datetime = ( |
| 187 | + execution_operation.end_timestamp if execution_operation else None |
| 188 | + ) or datetime.datetime.now() |
| 189 | + request_id = context.aws_request_id if context else None |
| 190 | + |
| 191 | + self.execute_plugins( |
| 192 | + InvocationEndInfo( |
| 193 | + request_id=request_id, |
| 194 | + execution_arn=durable_execution_arn, |
| 195 | + start_timestamp=start_timestamp, |
| 196 | + status=output.status, |
| 197 | + end_timestamp=end_timestamp, |
| 198 | + error=output.error, |
| 199 | + ) |
| 200 | + ) |
| 201 | + |
| 202 | + if output.status in [InvocationStatus.SUCCEEDED, InvocationStatus.FAILED]: |
| 203 | + self.execute_plugins( |
| 204 | + ExecutionEndInfo( |
| 205 | + request_id=request_id, |
| 206 | + execution_arn=durable_execution_arn, |
| 207 | + start_timestamp=start_timestamp, |
| 208 | + status=output.status, |
| 209 | + end_timestamp=end_timestamp, |
| 210 | + error=output.error, |
| 211 | + ) |
| 212 | + ) |
| 213 | + |
| 214 | + def on_operation_action(self, update): |
| 215 | + """Execute any registered plugins for a given operation before it is updated. |
| 216 | +
|
| 217 | + Args: |
| 218 | + update: the operation update that is pending checkpoint |
| 219 | + """ |
| 220 | + if update.action is OperationAction.START: |
| 221 | + self.execute_plugins( |
| 222 | + OperationStartInfo( |
| 223 | + operation_id=update.operation_id, |
| 224 | + operation_type=update.operation_type, |
| 225 | + sub_type=update.sub_type, |
| 226 | + name=update.name, |
| 227 | + parent_id=update.parent_id, |
| 228 | + start_timestamp=datetime.datetime.now(), |
| 229 | + ) |
| 230 | + ) |
| 231 | + |
| 232 | + def on_operation_update(self, operation): |
| 233 | + """Execute any registered plugins for a given operation after it is updated. |
| 234 | +
|
| 235 | + Updates such as STARTED might be omitted because START and completion action (e.g. SUCCEED/FAIL) may be |
| 236 | + checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED). |
| 237 | +
|
| 238 | + Args: |
| 239 | + operation: the operation is just checkpointed |
| 240 | + """ |
| 241 | + params = dict( |
| 242 | + operation_id=operation.operation_id, |
| 243 | + operation_type=operation.operation_type, |
| 244 | + sub_type=operation.sub_type, |
| 245 | + name=operation.name, |
| 246 | + parent_id=operation.parent_id, |
| 247 | + start_timestamp=operation.start_timestamp, |
| 248 | + ) |
| 249 | + # todo: Python SDK doesn't submit a START update when retrying |
| 250 | + if operation.step_details and ( |
| 251 | + self._is_terminal_status(operation.status) |
| 252 | + # PENDING in addition to terminal status |
| 253 | + or operation.status is OperationStatus.PENDING |
| 254 | + ): |
| 255 | + self.execute_plugins(AttemptStartInfo(**params)) |
| 256 | + self.execute_plugins( |
| 257 | + AttemptEndInfo( |
| 258 | + **params, |
| 259 | + end_timestamp=operation.end_timestamp, |
| 260 | + attempt=operation.step_details.attempt, |
| 261 | + succeeded=operation.status is OperationStatus.SUCCEEDED, |
| 262 | + error=operation.step_details.error, |
| 263 | + ) |
| 264 | + ) |
| 265 | + |
| 266 | + if self._is_terminal_status(operation.status): |
| 267 | + attempt = operation.step_details.attempt if operation.step_details else None |
| 268 | + self.execute_plugins(OperationStartInfo(**params)) |
| 269 | + self.execute_plugins( |
| 270 | + OperationEndInfo( |
| 271 | + **params, |
| 272 | + end_timestamp=operation.end_timestamp, |
| 273 | + status=operation.status, |
| 274 | + error=self._extract_error(operation), |
| 275 | + attempt=attempt, |
| 276 | + ) |
| 277 | + ) |
| 278 | + |
| 279 | + @staticmethod |
| 280 | + def _extract_error(operation: Operation): |
| 281 | + if operation.step_details and operation.step_details.error: |
| 282 | + return operation.step_details.error |
| 283 | + if operation.callback_details and operation.callback_details.error: |
| 284 | + return operation.callback_details.error |
| 285 | + if operation.chained_invoke_details and operation.chained_invoke_details.error: |
| 286 | + return operation.chained_invoke_details.error |
| 287 | + if operation.context_details and operation.context_details.error: |
| 288 | + return operation.context_details.error |
| 289 | + return None |
| 290 | + |
| 291 | + @staticmethod |
| 292 | + def _is_terminal_status(status): |
| 293 | + return status in [ |
| 294 | + OperationStatus.SUCCEEDED, |
| 295 | + OperationStatus.FAILED, |
| 296 | + OperationStatus.TIMED_OUT, |
| 297 | + OperationStatus.CANCELLED, |
| 298 | + OperationStatus.STOPPED, |
| 299 | + ] |
0 commit comments