55
66import pytest
77
8- from aws_durable_execution_sdk_python .config import CallbackConfig
8+ from aws_durable_execution_sdk_python .config import (
9+ CallbackConfig ,
10+ StepConfig ,
11+ WaitForCallbackConfig ,
12+ )
913from aws_durable_execution_sdk_python .exceptions import FatalError
1014from aws_durable_execution_sdk_python .identifier import OperationIdentifier
1115from aws_durable_execution_sdk_python .lambda_service import (
2226 create_callback_handler ,
2327 wait_for_callback_handler ,
2428)
29+ from aws_durable_execution_sdk_python .retries import RetryDecision
30+ from aws_durable_execution_sdk_python .serdes import SerDes
2531from aws_durable_execution_sdk_python .state import CheckpointedResult , ExecutionState
2632from aws_durable_execution_sdk_python .types import DurableContext , StepContext
2733
@@ -269,7 +275,7 @@ def test_wait_for_callback_handler_with_name_and_config():
269275 mock_callback .result .return_value = "named_callback_result"
270276 mock_context .create_callback .return_value = mock_callback
271277 mock_submitter = Mock ()
272- config = CallbackConfig ()
278+ config = WaitForCallbackConfig ()
273279
274280 result = wait_for_callback_handler (
275281 mock_context , mock_submitter , "test_callback" , config
@@ -291,7 +297,7 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
291297 mock_context .create_callback .return_value = mock_callback
292298 mock_submitter = Mock ()
293299
294- def capture_step_call (func , name ):
300+ def capture_step_call (func , name , config = None ):
295301 # Execute the step callable to verify submitter is called correctly
296302 step_context = Mock (spec = StepContext )
297303 func (step_context )
@@ -357,7 +363,7 @@ def test_wait_for_callback_handler_with_none_callback_id():
357363 mock_context .create_callback .return_value = mock_callback
358364 mock_submitter = Mock ()
359365
360- def execute_step (func , name ):
366+ def execute_step (func , name , config = None ):
361367 step_context = Mock (spec = StepContext )
362368 return func (step_context )
363369
@@ -378,7 +384,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():
378384 mock_context .create_callback .return_value = mock_callback
379385 mock_submitter = Mock ()
380386
381- def execute_step (func , name ):
387+ def execute_step (func , name , config = None ):
382388 step_context = Mock (spec = StepContext )
383389 return func (step_context )
384390
@@ -426,7 +432,9 @@ def test_wait_for_callback_handler_with_unicode_names():
426432
427433 assert result == f"result_for_{ name } "
428434 expected_name = f"{ name } submitter"
429- mock_context .step .assert_called_once_with (func = ANY , name = expected_name )
435+ mock_context .step .assert_called_once_with (
436+ func = ANY , name = expected_name , config = None
437+ )
430438 mock_context .reset_mock ()
431439
432440
@@ -591,7 +599,7 @@ def failing_submitter(callback_id):
591599 msg = "Submitter failed"
592600 raise ValueError (msg )
593601
594- def step_side_effect (func , name ):
602+ def step_side_effect (func , name , config = None ):
595603 step_context = Mock (spec = StepContext )
596604 func (step_context )
597605
@@ -675,7 +683,7 @@ def test_wait_for_callback_handler_config_propagation():
675683 mock_context .create_callback .return_value = mock_callback
676684 mock_submitter = Mock ()
677685
678- config = CallbackConfig (timeout_seconds = 120 , heartbeat_timeout_seconds = 30 )
686+ config = WaitForCallbackConfig (timeout_seconds = 120 , heartbeat_timeout_seconds = 30 )
679687
680688 result = wait_for_callback_handler (
681689 mock_context , mock_submitter , "config_test" , config
@@ -687,6 +695,41 @@ def test_wait_for_callback_handler_config_propagation():
687695 )
688696
689697
698+ def test_wait_for_callback_handler_step_config_propagation ():
699+ """Test wait_for_callback_handler properly passes retry_strategy and serdes to step config."""
700+
701+ mock_context = Mock (spec = DurableContext )
702+ mock_callback = Mock ()
703+ mock_callback .callback_id = "step_config_test"
704+ mock_callback .result .return_value = "step_config_result"
705+ mock_context .create_callback .return_value = mock_callback
706+ mock_submitter = Mock ()
707+
708+ def test_retry_strategy (exception , attempt ):
709+ return RetryDecision .retry_after_delay (1 )
710+
711+ mock_serdes = Mock (spec = SerDes )
712+
713+ config = WaitForCallbackConfig (
714+ retry_strategy = test_retry_strategy , serdes = mock_serdes
715+ )
716+
717+ result = wait_for_callback_handler (
718+ mock_context , mock_submitter , "step_config_test" , config
719+ )
720+
721+ assert result == "step_config_result"
722+
723+ # Verify step was called with correct StepConfig
724+ mock_context .step .assert_called_once ()
725+ call_args = mock_context .step .call_args
726+ step_config = call_args .kwargs ["config" ]
727+
728+ assert isinstance (step_config , StepConfig )
729+ assert step_config .retry_strategy == test_retry_strategy
730+ assert step_config .serdes == mock_serdes
731+
732+
690733def test_wait_for_callback_handler_with_various_result_types ():
691734 """Test wait_for_callback_handler with various result types."""
692735 result_types = [None , True , False , 0 , math .pi , "" , "string" , [], {"key" : "value" }]
@@ -729,7 +772,7 @@ def test_callback_lifecycle_complete_flow():
729772 mock_callback .result .return_value = {"status" : "completed" , "data" : "test_data" }
730773 mock_context .create_callback .return_value = mock_callback
731774
732- config = CallbackConfig (timeout_seconds = 300 , heartbeat_timeout_seconds = 60 )
775+ config = WaitForCallbackConfig (timeout_seconds = 300 , heartbeat_timeout_seconds = 60 )
733776 callback_id = create_callback_handler (
734777 state = mock_state ,
735778 operation_identifier = OperationIdentifier ("lifecycle_callback" , None ),
@@ -742,7 +785,7 @@ def mock_submitter(cb_id):
742785 assert cb_id == "lifecycle_cb123"
743786 return "submitted"
744787
745- def execute_step (func , name ):
788+ def execute_step (func , name , config = None ):
746789 step_context = Mock (spec = StepContext )
747790 return func (step_context )
748791
@@ -862,7 +905,7 @@ def complex_submitter(callback_id):
862905 msg = "Invalid callback ID"
863906 raise ValueError (msg )
864907
865- def execute_step (func , name ):
908+ def execute_step (func , name , config ):
866909 step_context = Mock (spec = StepContext )
867910 return func (step_context )
868911
@@ -942,7 +985,9 @@ def test_callback_name_variations():
942985
943986 assert result == f"result_for_{ name } "
944987 expected_name = f"{ name } submitter" if name else "submitter"
945- mock_context .step .assert_called_once_with (func = ANY , name = expected_name )
988+ mock_context .step .assert_called_once_with (
989+ func = ANY , name = expected_name , config = None
990+ )
946991 mock_context .reset_mock ()
947992
948993
0 commit comments