Skip to content

Commit 9295641

Browse files
authored
feat: add plugin interface (#371)
1 parent cc99ca6 commit 9295641

33 files changed

Lines changed: 3232 additions & 390 deletions

.github/hooks/pre-commit

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/sh
2+
3+
if hatch fmt --check; then
4+
echo "Hatch fmt check passed!"
5+
else
6+
hatch fmt
7+
echo "Error: hatch fmt modified your files. Please re-stage and commit again."
8+
exit 1
9+
fi

packages/aws-durable-execution-sdk-python-examples/examples-catalog.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,17 @@
602602
"ExecutionTimeout": 300
603603
},
604604
"path": "./src/parallel/parallel_with_named_branches.py"
605+
},
606+
{
607+
"name": "Plugin",
608+
"description": "Test plugin",
609+
"handler": "execution_with_plugin.handler",
610+
"integration": true,
611+
"durableConfig": {
612+
"RetentionPeriodInDays": 7,
613+
"ExecutionTimeout": 300
614+
},
615+
"path": "./src/plugin/execution_with_plugin.py"
605616
}
606617
]
607618
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Demonstrates handler execution without any durable operations."""
2+
3+
import logging
4+
from typing import Any
5+
6+
from aws_durable_execution_sdk_python import StepContext
7+
from aws_durable_execution_sdk_python.context import (
8+
DurableContext,
9+
durable_step,
10+
durable_with_child_context,
11+
)
12+
from aws_durable_execution_sdk_python.execution import durable_execution
13+
from aws_durable_execution_sdk_python.plugin import (
14+
DurableInstrumentationPlugin,
15+
)
16+
17+
18+
class MyPlugin(DurableInstrumentationPlugin):
19+
logger = logging.getLogger("MyPlugin")
20+
21+
def on_operation_start(self, info):
22+
self.logger.info(f"Operation started: {info}")
23+
24+
def on_operation_end(self, info):
25+
self.logger.info(f"Operation ended: {info}")
26+
27+
def on_invocation_start(self, info):
28+
self.logger.info(f"Invocation started: {info}")
29+
30+
def on_invocation_end(self, info):
31+
self.logger.info(f"Invocation ended: {info}")
32+
33+
def on_user_function_start(self, info) -> None:
34+
self.logger.info(f"User function started: {info}")
35+
36+
def on_user_function_end(self, info) -> None:
37+
self.logger.info(f"User function ended: {info}")
38+
39+
40+
@durable_step
41+
def add_numbers(_step_context: StepContext, a: int, b: int) -> int:
42+
return a + b
43+
44+
45+
@durable_with_child_context
46+
def add_numbers_in_child(child_context: DurableContext, a: int, b: int):
47+
result: int = child_context.step(
48+
add_numbers(a, b),
49+
name="add-a-and-b",
50+
)
51+
return result
52+
53+
54+
@durable_execution(plugins=[MyPlugin()])
55+
def handler(_event: Any, context: DurableContext) -> int:
56+
result: int = context.run_in_child_context(
57+
add_numbers_in_child(6, 4),
58+
name="add-6-and-4",
59+
)
60+
return context.step(
61+
add_numbers(result, 2),
62+
name="add-result-to-2",
63+
)

packages/aws-durable-execution-sdk-python-examples/template.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,24 @@
977977
"ExecutionTimeout": 300
978978
}
979979
}
980+
},
981+
"ExecutionWithPlugin": {
982+
"Type": "AWS::Serverless::Function",
983+
"Properties": {
984+
"CodeUri": "build/",
985+
"Handler": "execution_with_plugin.handler",
986+
"Description": "Test plugin",
987+
"Role": {
988+
"Fn::GetAtt": [
989+
"DurableFunctionRole",
990+
"Arn"
991+
]
992+
},
993+
"DurableConfig": {
994+
"RetentionPeriodInDays": 7,
995+
"ExecutionTimeout": 300
996+
}
997+
}
980998
}
981999
}
9821000
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Tests for step example."""
2+
3+
import pytest
4+
from aws_durable_execution_sdk_python.execution import InvocationStatus
5+
6+
from src.plugin import execution_with_plugin
7+
from test.conftest import deserialize_operation_payload
8+
9+
10+
@pytest.mark.example
11+
@pytest.mark.durable_execution(
12+
handler=execution_with_plugin.handler,
13+
lambda_function_name="Plugin",
14+
)
15+
def test_plugin(durable_runner):
16+
"""Test basic step example."""
17+
with durable_runner:
18+
result = durable_runner.run(input="{}", timeout=10)
19+
20+
assert result.status is InvocationStatus.SUCCEEDED
21+
assert deserialize_operation_payload(result.result) == 12
22+
23+
step_result = result.get_step("add-result-to-2")
24+
assert deserialize_operation_payload(step_result.result) == 12

packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
TimedSuspendExecution,
3131
)
3232
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
33-
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
33+
from aws_durable_execution_sdk_python.lambda_service import ErrorObject, OperationType
3434
from aws_durable_execution_sdk_python.operation.child import child_handler
3535

3636

@@ -428,9 +428,10 @@ def _execute_item_in_child_context(
428428
# For FLAT `child_handler` skips checkpoints, so not used.
429429
# Construct it unconditionally to keep the call simple.
430430
operation_identifier = OperationIdentifier(
431-
operation_id,
432-
executor_context._parent_id, # noqa: SLF001
433-
name,
431+
operation_id=operation_id,
432+
sub_type=self.sub_type_iteration,
433+
parent_id=executor_context._parent_id, # noqa: SLF001
434+
name=name,
434435
)
435436

436437
def run_in_child_handler() -> ResultType:

packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
ValidationError,
2424
)
2525
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
26-
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
26+
from aws_durable_execution_sdk_python.lambda_service import (
27+
OperationSubType,
28+
OperationType,
29+
)
2730
from aws_durable_execution_sdk_python.logger import Logger, LogInfo
2831
from aws_durable_execution_sdk_python.operation.callback import (
2932
CallbackOperationExecutor,
@@ -443,6 +446,7 @@ def create_callback(
443446
state=self.state,
444447
operation_identifier=OperationIdentifier(
445448
operation_id=operation_id,
449+
sub_type=OperationSubType.CALLBACK,
446450
parent_id=self._parent_id,
447451
name=name,
448452
),
@@ -485,6 +489,7 @@ def invoke(
485489
state=self.state,
486490
operation_identifier=OperationIdentifier(
487491
operation_id=operation_id,
492+
sub_type=OperationSubType.CHAINED_INVOKE,
488493
parent_id=self._parent_id,
489494
name=name,
490495
),
@@ -507,6 +512,7 @@ def map(
507512
operation_id = self._create_step_id()
508513
operation_identifier = OperationIdentifier(
509514
operation_id=operation_id,
515+
sub_type=OperationSubType.MAP,
510516
parent_id=self._parent_id,
511517
name=map_name,
512518
)
@@ -553,7 +559,10 @@ def parallel(
553559
operation_id = self._create_step_id()
554560
parallel_context = self.create_child_context(operation_id=operation_id)
555561
operation_identifier = OperationIdentifier(
556-
operation_id=operation_id, parent_id=self._parent_id, name=name
562+
operation_id=operation_id,
563+
sub_type=OperationSubType.PARALLEL,
564+
parent_id=self._parent_id,
565+
name=name,
557566
)
558567

559568
def parallel_in_child_context() -> BatchResult[T]:
@@ -606,6 +615,11 @@ def run_in_child_context(
606615
step_name: str | None = self._resolve_step_name(name, func)
607616
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
608617
operation_id = self._create_step_id()
618+
sub_type = (
619+
config.sub_type
620+
if config and config.sub_type
621+
else OperationSubType.RUN_IN_CHILD_CONTEXT
622+
)
609623

610624
is_virtual: bool = config.is_virtual if config else False
611625

@@ -621,6 +635,7 @@ def callable_with_child_context():
621635
state=self.state,
622636
operation_identifier=OperationIdentifier(
623637
operation_id=operation_id,
638+
sub_type=sub_type,
624639
parent_id=self._parent_id,
625640
name=step_name,
626641
),
@@ -646,6 +661,7 @@ def step(
646661
state=self.state,
647662
operation_identifier=OperationIdentifier(
648663
operation_id=operation_id,
664+
sub_type=OperationSubType.STEP,
649665
parent_id=self._parent_id,
650666
name=step_name,
651667
),
@@ -673,6 +689,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
673689
state=self.state,
674690
operation_identifier=OperationIdentifier(
675691
operation_id=operation_id,
692+
sub_type=OperationSubType.WAIT,
676693
parent_id=self._parent_id,
677694
name=name,
678695
),
@@ -728,6 +745,7 @@ def wait_for_condition(
728745
state=self.state,
729746
operation_identifier=OperationIdentifier(
730747
operation_id=operation_id,
748+
sub_type=OperationSubType.WAIT_FOR_CONDITION,
731749
parent_id=self._parent_id,
732750
name=name,
733751
),

0 commit comments

Comments
 (0)