@@ -980,168 +980,176 @@ def get_checkpoint_result(self, operation_id):
980980
981981def test_map_handler_serializes_batch_result ():
982982 """Verify map_handler serializes BatchResult at parent level."""
983- with patch (
984- "aws_durable_execution_sdk_python.serdes.serialize"
985- ) as mock_serdes_serialize :
986- mock_serdes_serialize .return_value = '"serialized"'
987- importlib .reload (child )
983+ try :
984+ with patch (
985+ "aws_durable_execution_sdk_python.serdes.serialize"
986+ ) as mock_serdes_serialize :
987+ mock_serdes_serialize .return_value = '"serialized"'
988+ importlib .reload (child )
989+
990+ parent_checkpoint = Mock ()
991+ parent_checkpoint .is_succeeded .return_value = False
992+ parent_checkpoint .is_failed .return_value = False
993+ parent_checkpoint .is_existent .return_value = False
994+ parent_checkpoint .is_replay_children .return_value = False
995+
996+ child_checkpoint = Mock ()
997+ child_checkpoint .is_succeeded .return_value = False
998+ child_checkpoint .is_failed .return_value = False
999+ child_checkpoint .is_existent .return_value = False
1000+ child_checkpoint .is_replay_children .return_value = False
1001+
1002+ def get_checkpoint (op_id ):
1003+ return child_checkpoint if op_id .startswith ("child-" ) else parent_checkpoint
1004+
1005+ mock_state = Mock ()
1006+ mock_state .durable_execution_arn = "arn:test"
1007+ mock_state .get_checkpoint_result = Mock (side_effect = get_checkpoint )
1008+ mock_state .create_checkpoint = Mock ()
1009+
1010+ context_map = {}
1011+
1012+ def create_id (self , i ):
1013+ ctx_id = id (self )
1014+ if ctx_id not in context_map :
1015+ context_map [ctx_id ] = []
1016+ context_map [ctx_id ].append (i )
1017+ return (
1018+ "parent"
1019+ if len (context_map ) == 1 and len (context_map [ctx_id ]) == 1
1020+ else f"child-{ i } "
1021+ )
9881022
989- parent_checkpoint = Mock ()
990- parent_checkpoint .is_succeeded .return_value = False
991- parent_checkpoint .is_failed .return_value = False
992- parent_checkpoint .is_existent .return_value = False
993- parent_checkpoint .is_replay_children .return_value = False
994-
995- child_checkpoint = Mock ()
996- child_checkpoint .is_succeeded .return_value = False
997- child_checkpoint .is_failed .return_value = False
998- child_checkpoint .is_existent .return_value = False
999- child_checkpoint .is_replay_children .return_value = False
1000-
1001- def get_checkpoint (op_id ):
1002- return child_checkpoint if op_id .startswith ("child-" ) else parent_checkpoint
1003-
1004- mock_state = Mock ()
1005- mock_state .durable_execution_arn = "arn:test"
1006- mock_state .get_checkpoint_result = Mock (side_effect = get_checkpoint )
1007- mock_state .create_checkpoint = Mock ()
1008-
1009- context_map = {}
1010-
1011- def create_id (self , i ):
1012- ctx_id = id (self )
1013- if ctx_id not in context_map :
1014- context_map [ctx_id ] = []
1015- context_map [ctx_id ].append (i )
1016- return (
1017- "parent"
1018- if len (context_map ) == 1 and len (context_map [ctx_id ]) == 1
1019- else f"child-{ i } "
1020- )
1021-
1022- with patch .object (
1023- DurableContext , "_create_step_id_for_logical_step" , create_id
1024- ):
1025- context = create_test_context (state = mock_state )
1026- result = context .map (["a" , "b" ], lambda ctx , item , idx , items : item )
1027-
1028- assert len (mock_serdes_serialize .call_args_list ) == 3
1029- parent_call = mock_serdes_serialize .call_args_list [2 ]
1030- assert parent_call [1 ]["value" ] is result
1023+ with patch .object (
1024+ DurableContext , "_create_step_id_for_logical_step" , create_id
1025+ ):
1026+ context = create_test_context (state = mock_state )
1027+ result = context .map (["a" , "b" ], lambda ctx , item , idx , items : item )
1028+
1029+ assert len (mock_serdes_serialize .call_args_list ) == 3
1030+ parent_call = mock_serdes_serialize .call_args_list [2 ]
1031+ assert parent_call [1 ]["value" ] is result
1032+ finally :
1033+ importlib .reload (child )
10311034
10321035
10331036def test_map_default_serdes_serializes_batch_result ():
10341037 """Verify default serdes automatically serializes BatchResult."""
1038+ try :
1039+ with patch (
1040+ "aws_durable_execution_sdk_python.serdes.serialize" , wraps = serialize
1041+ ) as mock_serialize :
1042+ importlib .reload (child )
1043+
1044+ parent_checkpoint = Mock ()
1045+ parent_checkpoint .is_succeeded .return_value = False
1046+ parent_checkpoint .is_failed .return_value = False
1047+ parent_checkpoint .is_existent .return_value = False
1048+ parent_checkpoint .is_replay_children .return_value = False
1049+
1050+ child_checkpoint = Mock ()
1051+ child_checkpoint .is_succeeded .return_value = False
1052+ child_checkpoint .is_failed .return_value = False
1053+ child_checkpoint .is_existent .return_value = False
1054+ child_checkpoint .is_replay_children .return_value = False
1055+
1056+ def get_checkpoint (op_id ):
1057+ return child_checkpoint if op_id .startswith ("child-" ) else parent_checkpoint
1058+
1059+ mock_state = Mock ()
1060+ mock_state .durable_execution_arn = "arn:test"
1061+ mock_state .get_checkpoint_result = Mock (side_effect = get_checkpoint )
1062+ mock_state .create_checkpoint = Mock ()
1063+
1064+ context_map = {}
1065+
1066+ def create_id (self , i ):
1067+ ctx_id = id (self )
1068+ if ctx_id not in context_map :
1069+ context_map [ctx_id ] = []
1070+ context_map [ctx_id ].append (i )
1071+ return (
1072+ "parent"
1073+ if len (context_map ) == 1 and len (context_map [ctx_id ]) == 1
1074+ else f"child-{ i } "
1075+ )
10351076
1036- with patch (
1037- "aws_durable_execution_sdk_python.serdes.serialize" , wraps = serialize
1038- ) as mock_serialize :
1077+ with patch .object (
1078+ DurableContext , "_create_step_id_for_logical_step" , create_id
1079+ ):
1080+ context = create_test_context (state = mock_state )
1081+ result = context .map (["a" , "b" ], lambda ctx , item , idx , items : item )
1082+
1083+ assert isinstance (result , BatchResult )
1084+ assert len (mock_serialize .call_args_list ) == 3
1085+ parent_call = mock_serialize .call_args_list [2 ]
1086+ assert parent_call [1 ]["serdes" ] is None
1087+ assert isinstance (parent_call [1 ]["value" ], BatchResult )
1088+ assert parent_call [1 ]["value" ] is result
1089+ finally :
10391090 importlib .reload (child )
10401091
1041- parent_checkpoint = Mock ()
1042- parent_checkpoint .is_succeeded .return_value = False
1043- parent_checkpoint .is_failed .return_value = False
1044- parent_checkpoint .is_existent .return_value = False
1045- parent_checkpoint .is_replay_children .return_value = False
1046-
1047- child_checkpoint = Mock ()
1048- child_checkpoint .is_succeeded .return_value = False
1049- child_checkpoint .is_failed .return_value = False
1050- child_checkpoint .is_existent .return_value = False
1051- child_checkpoint .is_replay_children .return_value = False
1052-
1053- def get_checkpoint (op_id ):
1054- return child_checkpoint if op_id .startswith ("child-" ) else parent_checkpoint
1055-
1056- mock_state = Mock ()
1057- mock_state .durable_execution_arn = "arn:test"
1058- mock_state .get_checkpoint_result = Mock (side_effect = get_checkpoint )
1059- mock_state .create_checkpoint = Mock ()
1060-
1061- context_map = {}
1062-
1063- def create_id (self , i ):
1064- ctx_id = id (self )
1065- if ctx_id not in context_map :
1066- context_map [ctx_id ] = []
1067- context_map [ctx_id ].append (i )
1068- return (
1069- "parent"
1070- if len (context_map ) == 1 and len (context_map [ctx_id ]) == 1
1071- else f"child-{ i } "
1072- )
1073-
1074- with patch .object (
1075- DurableContext , "_create_step_id_for_logical_step" , create_id
1076- ):
1077- context = create_test_context (state = mock_state )
1078- result = context .map (["a" , "b" ], lambda ctx , item , idx , items : item )
1079-
1080- assert isinstance (result , BatchResult )
1081- assert len (mock_serialize .call_args_list ) == 3
1082- parent_call = mock_serialize .call_args_list [2 ]
1083- assert parent_call [1 ]["serdes" ] is None
1084- assert isinstance (parent_call [1 ]["value" ], BatchResult )
1085- assert parent_call [1 ]["value" ] is result
1086-
10871092
10881093def test_map_custom_serdes_serializes_batch_result ():
10891094 """Verify custom serdes is used for BatchResult serialization."""
10901095
10911096 custom_serdes = CustomStrSerDes ()
10921097
1093- with patch ("aws_durable_execution_sdk_python.serdes.serialize" ) as mock_serialize :
1094- mock_serialize .return_value = '"serialized"'
1095- importlib .reload (child )
1098+ try :
1099+ with patch ("aws_durable_execution_sdk_python.serdes.serialize" ) as mock_serialize :
1100+ mock_serialize .return_value = '"serialized"'
1101+ importlib .reload (child )
1102+
1103+ parent_checkpoint = Mock ()
1104+ parent_checkpoint .is_succeeded .return_value = False
1105+ parent_checkpoint .is_failed .return_value = False
1106+ parent_checkpoint .is_existent .return_value = False
1107+ parent_checkpoint .is_replay_children .return_value = False
1108+
1109+ child_checkpoint = Mock ()
1110+ child_checkpoint .is_succeeded .return_value = False
1111+ child_checkpoint .is_failed .return_value = False
1112+ child_checkpoint .is_existent .return_value = False
1113+ child_checkpoint .is_replay_children .return_value = False
1114+
1115+ def get_checkpoint (op_id ):
1116+ return child_checkpoint if op_id .startswith ("child-" ) else parent_checkpoint
1117+
1118+ mock_state = Mock ()
1119+ mock_state .durable_execution_arn = "arn:test"
1120+ mock_state .get_checkpoint_result = Mock (side_effect = get_checkpoint )
1121+ mock_state .create_checkpoint = Mock ()
1122+
1123+ context_map = {}
1124+
1125+ def create_id (self , i ):
1126+ ctx_id = id (self )
1127+ if ctx_id not in context_map :
1128+ context_map [ctx_id ] = []
1129+ context_map [ctx_id ].append (i )
1130+ return (
1131+ "parent"
1132+ if len (context_map ) == 1 and len (context_map [ctx_id ]) == 1
1133+ else f"child-{ i } "
1134+ )
10961135
1097- parent_checkpoint = Mock ()
1098- parent_checkpoint .is_succeeded .return_value = False
1099- parent_checkpoint .is_failed .return_value = False
1100- parent_checkpoint .is_existent .return_value = False
1101- parent_checkpoint .is_replay_children .return_value = False
1102-
1103- child_checkpoint = Mock ()
1104- child_checkpoint .is_succeeded .return_value = False
1105- child_checkpoint .is_failed .return_value = False
1106- child_checkpoint .is_existent .return_value = False
1107- child_checkpoint .is_replay_children .return_value = False
1108-
1109- def get_checkpoint (op_id ):
1110- return child_checkpoint if op_id .startswith ("child-" ) else parent_checkpoint
1111-
1112- mock_state = Mock ()
1113- mock_state .durable_execution_arn = "arn:test"
1114- mock_state .get_checkpoint_result = Mock (side_effect = get_checkpoint )
1115- mock_state .create_checkpoint = Mock ()
1116-
1117- context_map = {}
1118-
1119- def create_id (self , i ):
1120- ctx_id = id (self )
1121- if ctx_id not in context_map :
1122- context_map [ctx_id ] = []
1123- context_map [ctx_id ].append (i )
1124- return (
1125- "parent"
1126- if len (context_map ) == 1 and len (context_map [ctx_id ]) == 1
1127- else f"child-{ i } "
1128- )
1129-
1130- with patch .object (
1131- DurableContext , "_create_step_id_for_logical_step" , create_id
1132- ):
1133- context = create_test_context (state = mock_state )
1134- result = context .map (
1135- ["a" , "b" ],
1136- lambda ctx , item , idx , items : item ,
1137- config = MapConfig (serdes = custom_serdes ),
1138- )
1139-
1140- assert isinstance (result , BatchResult )
1141- assert len (mock_serialize .call_args_list ) == 3
1142- parent_call = mock_serialize .call_args_list [2 ]
1143- assert parent_call [1 ]["serdes" ] is custom_serdes
1144- assert isinstance (parent_call [1 ]["value" ], BatchResult )
1136+ with patch .object (
1137+ DurableContext , "_create_step_id_for_logical_step" , create_id
1138+ ):
1139+ context = create_test_context (state = mock_state )
1140+ result = context .map (
1141+ ["a" , "b" ],
1142+ lambda ctx , item , idx , items : item ,
1143+ config = MapConfig (serdes = custom_serdes ),
1144+ )
1145+
1146+ assert isinstance (result , BatchResult )
1147+ assert len (mock_serialize .call_args_list ) == 3
1148+ parent_call = mock_serialize .call_args_list [2 ]
1149+ assert parent_call [1 ]["serdes" ] is custom_serdes
1150+ assert isinstance (parent_call [1 ]["value" ], BatchResult )
1151+ finally :
1152+ importlib .reload (child )
11451153 assert parent_call [1 ]["value" ] is result
11461154
11471155
0 commit comments