Skip to content

Commit 5d36522

Browse files
author
Rares Polenciuc
committed
pass retry config to the Step that wraps the submitter
1 parent 87f08b2 commit 5d36522

2 files changed

Lines changed: 22 additions & 11 deletions

File tree

src/aws_durable_execution_sdk_python/operation/callback.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7+
from aws_durable_execution_sdk_python.config import StepConfig
78
from aws_durable_execution_sdk_python.exceptions import FatalError
89
from aws_durable_execution_sdk_python.lambda_service import (
910
CallbackOptions,
@@ -16,7 +17,7 @@
1617
from aws_durable_execution_sdk_python.config import (
1718
CallbackConfig,
1819
WaitForCallbackConfig,
19-
)
20+
)
2021
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2122
from aws_durable_execution_sdk_python.state import (
2223
CheckpointedResult,
@@ -97,6 +98,13 @@ def wait_for_callback_handler(
9798
def submitter_step(step_context): # noqa: ARG001
9899
return submitter(callback.callback_id)
99100

100-
context.step(func=submitter_step, name=f"{name_with_space}submitter")
101+
if config:
102+
step_config = StepConfig(
103+
retry_strategy=config.retry_strategy,
104+
serdes=config.serdes,
105+
)
106+
context.step(func=submitter_step, name=f"{name_with_space}submitter", config=step_config)
107+
else:
108+
context.step(func=submitter_step, name=f"{name_with_space}submitter")
101109

102110
return callback.result()

tests/operation/callback_test.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
import pytest
77

8-
from aws_durable_execution_sdk_python.config import CallbackConfig
8+
from aws_durable_execution_sdk_python.config import (
9+
CallbackConfig,
10+
WaitForCallbackConfig,
11+
)
912
from aws_durable_execution_sdk_python.exceptions import FatalError
1013
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1114
from aws_durable_execution_sdk_python.lambda_service import (
@@ -269,7 +272,7 @@ def test_wait_for_callback_handler_with_name_and_config():
269272
mock_callback.result.return_value = "named_callback_result"
270273
mock_context.create_callback.return_value = mock_callback
271274
mock_submitter = Mock()
272-
config = CallbackConfig()
275+
config = WaitForCallbackConfig()
273276

274277
result = wait_for_callback_handler(
275278
mock_context, mock_submitter, "test_callback", config
@@ -291,7 +294,7 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
291294
mock_context.create_callback.return_value = mock_callback
292295
mock_submitter = Mock()
293296

294-
def capture_step_call(func, name):
297+
def capture_step_call(func, name, config=None):
295298
# Execute the step callable to verify submitter is called correctly
296299
step_context = Mock(spec=StepContext)
297300
func(step_context)
@@ -357,7 +360,7 @@ def test_wait_for_callback_handler_with_none_callback_id():
357360
mock_context.create_callback.return_value = mock_callback
358361
mock_submitter = Mock()
359362

360-
def execute_step(func, name):
363+
def execute_step(func, name, config=None):
361364
step_context = Mock(spec=StepContext)
362365
return func(step_context)
363366

@@ -378,7 +381,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():
378381
mock_context.create_callback.return_value = mock_callback
379382
mock_submitter = Mock()
380383

381-
def execute_step(func, name):
384+
def execute_step(func, name, config=None):
382385
step_context = Mock(spec=StepContext)
383386
return func(step_context)
384387

@@ -591,7 +594,7 @@ def failing_submitter(callback_id):
591594
msg = "Submitter failed"
592595
raise ValueError(msg)
593596

594-
def step_side_effect(func, name):
597+
def step_side_effect(func, name, config=None):
595598
step_context = Mock(spec=StepContext)
596599
func(step_context)
597600

@@ -675,7 +678,7 @@ def test_wait_for_callback_handler_config_propagation():
675678
mock_context.create_callback.return_value = mock_callback
676679
mock_submitter = Mock()
677680

678-
config = CallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
681+
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
679682

680683
result = wait_for_callback_handler(
681684
mock_context, mock_submitter, "config_test", config
@@ -729,7 +732,7 @@ def test_callback_lifecycle_complete_flow():
729732
mock_callback.result.return_value = {"status": "completed", "data": "test_data"}
730733
mock_context.create_callback.return_value = mock_callback
731734

732-
config = CallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
735+
config = WaitForCallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
733736
callback_id = create_callback_handler(
734737
state=mock_state,
735738
operation_identifier=OperationIdentifier("lifecycle_callback", None),
@@ -742,7 +745,7 @@ def mock_submitter(cb_id):
742745
assert cb_id == "lifecycle_cb123"
743746
return "submitted"
744747

745-
def execute_step(func, name):
748+
def execute_step(func, name, config=None):
746749
step_context = Mock(spec=StepContext)
747750
return func(step_context)
748751

0 commit comments

Comments
 (0)