Skip to content

Commit 9c47087

Browse files
Rares PolenciucFullyTyped
authored andcommitted
feat: pass retry config to the Step that wraps the submitter
1 parent 5d29be8 commit 9c47087

2 files changed

Lines changed: 69 additions & 13 deletions

File tree

src/aws_durable_execution_sdk_python/operation/callback.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7+
from aws_durable_execution_sdk_python.config import StepConfig
78
from aws_durable_execution_sdk_python.exceptions import FatalError
89
from aws_durable_execution_sdk_python.lambda_service import (
910
CallbackOptions,
@@ -97,6 +98,16 @@ def wait_for_callback_handler(
9798
def submitter_step(step_context): # noqa: ARG001
9899
return submitter(callback.callback_id)
99100

100-
context.step(func=submitter_step, name=f"{name_with_space}submitter")
101+
step_config = (
102+
StepConfig(
103+
retry_strategy=config.retry_strategy,
104+
serdes=config.serdes,
105+
)
106+
if config
107+
else None
108+
)
109+
context.step(
110+
func=submitter_step, name=f"{name_with_space}submitter", config=step_config
111+
)
101112

102113
return callback.result()

tests/operation/callback_test.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
import pytest
77

8-
from aws_durable_execution_sdk_python.config import CallbackConfig
8+
from aws_durable_execution_sdk_python.config import (
9+
CallbackConfig,
10+
StepConfig,
11+
WaitForCallbackConfig,
12+
)
913
from aws_durable_execution_sdk_python.exceptions import FatalError
1014
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1115
from aws_durable_execution_sdk_python.lambda_service import (
@@ -22,6 +26,8 @@
2226
create_callback_handler,
2327
wait_for_callback_handler,
2428
)
29+
from aws_durable_execution_sdk_python.retries import RetryDecision
30+
from aws_durable_execution_sdk_python.serdes import SerDes
2531
from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState
2632
from aws_durable_execution_sdk_python.types import DurableContext, StepContext
2733

@@ -269,7 +275,7 @@ def test_wait_for_callback_handler_with_name_and_config():
269275
mock_callback.result.return_value = "named_callback_result"
270276
mock_context.create_callback.return_value = mock_callback
271277
mock_submitter = Mock()
272-
config = CallbackConfig()
278+
config = WaitForCallbackConfig()
273279

274280
result = wait_for_callback_handler(
275281
mock_context, mock_submitter, "test_callback", config
@@ -291,7 +297,7 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
291297
mock_context.create_callback.return_value = mock_callback
292298
mock_submitter = Mock()
293299

294-
def capture_step_call(func, name):
300+
def capture_step_call(func, name, config=None):
295301
# Execute the step callable to verify submitter is called correctly
296302
step_context = Mock(spec=StepContext)
297303
func(step_context)
@@ -357,7 +363,7 @@ def test_wait_for_callback_handler_with_none_callback_id():
357363
mock_context.create_callback.return_value = mock_callback
358364
mock_submitter = Mock()
359365

360-
def execute_step(func, name):
366+
def execute_step(func, name, config=None):
361367
step_context = Mock(spec=StepContext)
362368
return func(step_context)
363369

@@ -378,7 +384,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():
378384
mock_context.create_callback.return_value = mock_callback
379385
mock_submitter = Mock()
380386

381-
def execute_step(func, name):
387+
def execute_step(func, name, config=None):
382388
step_context = Mock(spec=StepContext)
383389
return func(step_context)
384390

@@ -426,7 +432,9 @@ def test_wait_for_callback_handler_with_unicode_names():
426432

427433
assert result == f"result_for_{name}"
428434
expected_name = f"{name} submitter"
429-
mock_context.step.assert_called_once_with(func=ANY, name=expected_name)
435+
mock_context.step.assert_called_once_with(
436+
func=ANY, name=expected_name, config=None
437+
)
430438
mock_context.reset_mock()
431439

432440

@@ -591,7 +599,7 @@ def failing_submitter(callback_id):
591599
msg = "Submitter failed"
592600
raise ValueError(msg)
593601

594-
def step_side_effect(func, name):
602+
def step_side_effect(func, name, config=None):
595603
step_context = Mock(spec=StepContext)
596604
func(step_context)
597605

@@ -675,7 +683,7 @@ def test_wait_for_callback_handler_config_propagation():
675683
mock_context.create_callback.return_value = mock_callback
676684
mock_submitter = Mock()
677685

678-
config = CallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
686+
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
679687

680688
result = wait_for_callback_handler(
681689
mock_context, mock_submitter, "config_test", config
@@ -687,6 +695,41 @@ def test_wait_for_callback_handler_config_propagation():
687695
)
688696

689697

698+
def test_wait_for_callback_handler_step_config_propagation():
699+
"""Test wait_for_callback_handler properly passes retry_strategy and serdes to step config."""
700+
701+
mock_context = Mock(spec=DurableContext)
702+
mock_callback = Mock()
703+
mock_callback.callback_id = "step_config_test"
704+
mock_callback.result.return_value = "step_config_result"
705+
mock_context.create_callback.return_value = mock_callback
706+
mock_submitter = Mock()
707+
708+
def test_retry_strategy(exception, attempt):
709+
return RetryDecision.retry_after_delay(1)
710+
711+
mock_serdes = Mock(spec=SerDes)
712+
713+
config = WaitForCallbackConfig(
714+
retry_strategy=test_retry_strategy, serdes=mock_serdes
715+
)
716+
717+
result = wait_for_callback_handler(
718+
mock_context, mock_submitter, "step_config_test", config
719+
)
720+
721+
assert result == "step_config_result"
722+
723+
# Verify step was called with correct StepConfig
724+
mock_context.step.assert_called_once()
725+
call_args = mock_context.step.call_args
726+
step_config = call_args.kwargs["config"]
727+
728+
assert isinstance(step_config, StepConfig)
729+
assert step_config.retry_strategy == test_retry_strategy
730+
assert step_config.serdes == mock_serdes
731+
732+
690733
def test_wait_for_callback_handler_with_various_result_types():
691734
"""Test wait_for_callback_handler with various result types."""
692735
result_types = [None, True, False, 0, math.pi, "", "string", [], {"key": "value"}]
@@ -729,7 +772,7 @@ def test_callback_lifecycle_complete_flow():
729772
mock_callback.result.return_value = {"status": "completed", "data": "test_data"}
730773
mock_context.create_callback.return_value = mock_callback
731774

732-
config = CallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
775+
config = WaitForCallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
733776
callback_id = create_callback_handler(
734777
state=mock_state,
735778
operation_identifier=OperationIdentifier("lifecycle_callback", None),
@@ -742,7 +785,7 @@ def mock_submitter(cb_id):
742785
assert cb_id == "lifecycle_cb123"
743786
return "submitted"
744787

745-
def execute_step(func, name):
788+
def execute_step(func, name, config=None):
746789
step_context = Mock(spec=StepContext)
747790
return func(step_context)
748791

@@ -862,7 +905,7 @@ def complex_submitter(callback_id):
862905
msg = "Invalid callback ID"
863906
raise ValueError(msg)
864907

865-
def execute_step(func, name):
908+
def execute_step(func, name, config):
866909
step_context = Mock(spec=StepContext)
867910
return func(step_context)
868911

@@ -942,7 +985,9 @@ def test_callback_name_variations():
942985

943986
assert result == f"result_for_{name}"
944987
expected_name = f"{name} submitter" if name else "submitter"
945-
mock_context.step.assert_called_once_with(func=ANY, name=expected_name)
988+
mock_context.step.assert_called_once_with(
989+
func=ANY, name=expected_name, config=None
990+
)
946991
mock_context.reset_mock()
947992

948993

0 commit comments

Comments
 (0)