Skip to content

Commit 72af430

Browse files
committed
fix: fix failing tests due to test pollution
1 parent b1c55ae commit 72af430

2 files changed

Lines changed: 310 additions & 294 deletions

File tree

packages/aws-durable-execution-sdk-python/tests/operation/map_test.py

Lines changed: 155 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -980,168 +980,176 @@ def get_checkpoint_result(self, operation_id):
980980

981981
def test_map_handler_serializes_batch_result():
982982
"""Verify map_handler serializes BatchResult at parent level."""
983-
with patch(
984-
"aws_durable_execution_sdk_python.serdes.serialize"
985-
) as mock_serdes_serialize:
986-
mock_serdes_serialize.return_value = '"serialized"'
987-
importlib.reload(child)
983+
try:
984+
with patch(
985+
"aws_durable_execution_sdk_python.serdes.serialize"
986+
) as mock_serdes_serialize:
987+
mock_serdes_serialize.return_value = '"serialized"'
988+
importlib.reload(child)
989+
990+
parent_checkpoint = Mock()
991+
parent_checkpoint.is_succeeded.return_value = False
992+
parent_checkpoint.is_failed.return_value = False
993+
parent_checkpoint.is_existent.return_value = False
994+
parent_checkpoint.is_replay_children.return_value = False
995+
996+
child_checkpoint = Mock()
997+
child_checkpoint.is_succeeded.return_value = False
998+
child_checkpoint.is_failed.return_value = False
999+
child_checkpoint.is_existent.return_value = False
1000+
child_checkpoint.is_replay_children.return_value = False
1001+
1002+
def get_checkpoint(op_id):
1003+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
1004+
1005+
mock_state = Mock()
1006+
mock_state.durable_execution_arn = "arn:test"
1007+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
1008+
mock_state.create_checkpoint = Mock()
1009+
1010+
context_map = {}
1011+
1012+
def create_id(self, i):
1013+
ctx_id = id(self)
1014+
if ctx_id not in context_map:
1015+
context_map[ctx_id] = []
1016+
context_map[ctx_id].append(i)
1017+
return (
1018+
"parent"
1019+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
1020+
else f"child-{i}"
1021+
)
9881022

989-
parent_checkpoint = Mock()
990-
parent_checkpoint.is_succeeded.return_value = False
991-
parent_checkpoint.is_failed.return_value = False
992-
parent_checkpoint.is_existent.return_value = False
993-
parent_checkpoint.is_replay_children.return_value = False
994-
995-
child_checkpoint = Mock()
996-
child_checkpoint.is_succeeded.return_value = False
997-
child_checkpoint.is_failed.return_value = False
998-
child_checkpoint.is_existent.return_value = False
999-
child_checkpoint.is_replay_children.return_value = False
1000-
1001-
def get_checkpoint(op_id):
1002-
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
1003-
1004-
mock_state = Mock()
1005-
mock_state.durable_execution_arn = "arn:test"
1006-
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
1007-
mock_state.create_checkpoint = Mock()
1008-
1009-
context_map = {}
1010-
1011-
def create_id(self, i):
1012-
ctx_id = id(self)
1013-
if ctx_id not in context_map:
1014-
context_map[ctx_id] = []
1015-
context_map[ctx_id].append(i)
1016-
return (
1017-
"parent"
1018-
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
1019-
else f"child-{i}"
1020-
)
1021-
1022-
with patch.object(
1023-
DurableContext, "_create_step_id_for_logical_step", create_id
1024-
):
1025-
context = create_test_context(state=mock_state)
1026-
result = context.map(["a", "b"], lambda ctx, item, idx, items: item)
1027-
1028-
assert len(mock_serdes_serialize.call_args_list) == 3
1029-
parent_call = mock_serdes_serialize.call_args_list[2]
1030-
assert parent_call[1]["value"] is result
1023+
with patch.object(
1024+
DurableContext, "_create_step_id_for_logical_step", create_id
1025+
):
1026+
context = create_test_context(state=mock_state)
1027+
result = context.map(["a", "b"], lambda ctx, item, idx, items: item)
1028+
1029+
assert len(mock_serdes_serialize.call_args_list) == 3
1030+
parent_call = mock_serdes_serialize.call_args_list[2]
1031+
assert parent_call[1]["value"] is result
1032+
finally:
1033+
importlib.reload(child)
10311034

10321035

10331036
def test_map_default_serdes_serializes_batch_result():
10341037
"""Verify default serdes automatically serializes BatchResult."""
1038+
try:
1039+
with patch(
1040+
"aws_durable_execution_sdk_python.serdes.serialize", wraps=serialize
1041+
) as mock_serialize:
1042+
importlib.reload(child)
1043+
1044+
parent_checkpoint = Mock()
1045+
parent_checkpoint.is_succeeded.return_value = False
1046+
parent_checkpoint.is_failed.return_value = False
1047+
parent_checkpoint.is_existent.return_value = False
1048+
parent_checkpoint.is_replay_children.return_value = False
1049+
1050+
child_checkpoint = Mock()
1051+
child_checkpoint.is_succeeded.return_value = False
1052+
child_checkpoint.is_failed.return_value = False
1053+
child_checkpoint.is_existent.return_value = False
1054+
child_checkpoint.is_replay_children.return_value = False
1055+
1056+
def get_checkpoint(op_id):
1057+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
1058+
1059+
mock_state = Mock()
1060+
mock_state.durable_execution_arn = "arn:test"
1061+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
1062+
mock_state.create_checkpoint = Mock()
1063+
1064+
context_map = {}
1065+
1066+
def create_id(self, i):
1067+
ctx_id = id(self)
1068+
if ctx_id not in context_map:
1069+
context_map[ctx_id] = []
1070+
context_map[ctx_id].append(i)
1071+
return (
1072+
"parent"
1073+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
1074+
else f"child-{i}"
1075+
)
10351076

1036-
with patch(
1037-
"aws_durable_execution_sdk_python.serdes.serialize", wraps=serialize
1038-
) as mock_serialize:
1077+
with patch.object(
1078+
DurableContext, "_create_step_id_for_logical_step", create_id
1079+
):
1080+
context = create_test_context(state=mock_state)
1081+
result = context.map(["a", "b"], lambda ctx, item, idx, items: item)
1082+
1083+
assert isinstance(result, BatchResult)
1084+
assert len(mock_serialize.call_args_list) == 3
1085+
parent_call = mock_serialize.call_args_list[2]
1086+
assert parent_call[1]["serdes"] is None
1087+
assert isinstance(parent_call[1]["value"], BatchResult)
1088+
assert parent_call[1]["value"] is result
1089+
finally:
10391090
importlib.reload(child)
10401091

1041-
parent_checkpoint = Mock()
1042-
parent_checkpoint.is_succeeded.return_value = False
1043-
parent_checkpoint.is_failed.return_value = False
1044-
parent_checkpoint.is_existent.return_value = False
1045-
parent_checkpoint.is_replay_children.return_value = False
1046-
1047-
child_checkpoint = Mock()
1048-
child_checkpoint.is_succeeded.return_value = False
1049-
child_checkpoint.is_failed.return_value = False
1050-
child_checkpoint.is_existent.return_value = False
1051-
child_checkpoint.is_replay_children.return_value = False
1052-
1053-
def get_checkpoint(op_id):
1054-
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
1055-
1056-
mock_state = Mock()
1057-
mock_state.durable_execution_arn = "arn:test"
1058-
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
1059-
mock_state.create_checkpoint = Mock()
1060-
1061-
context_map = {}
1062-
1063-
def create_id(self, i):
1064-
ctx_id = id(self)
1065-
if ctx_id not in context_map:
1066-
context_map[ctx_id] = []
1067-
context_map[ctx_id].append(i)
1068-
return (
1069-
"parent"
1070-
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
1071-
else f"child-{i}"
1072-
)
1073-
1074-
with patch.object(
1075-
DurableContext, "_create_step_id_for_logical_step", create_id
1076-
):
1077-
context = create_test_context(state=mock_state)
1078-
result = context.map(["a", "b"], lambda ctx, item, idx, items: item)
1079-
1080-
assert isinstance(result, BatchResult)
1081-
assert len(mock_serialize.call_args_list) == 3
1082-
parent_call = mock_serialize.call_args_list[2]
1083-
assert parent_call[1]["serdes"] is None
1084-
assert isinstance(parent_call[1]["value"], BatchResult)
1085-
assert parent_call[1]["value"] is result
1086-
10871092

10881093
def test_map_custom_serdes_serializes_batch_result():
10891094
"""Verify custom serdes is used for BatchResult serialization."""
10901095

10911096
custom_serdes = CustomStrSerDes()
10921097

1093-
with patch("aws_durable_execution_sdk_python.serdes.serialize") as mock_serialize:
1094-
mock_serialize.return_value = '"serialized"'
1095-
importlib.reload(child)
1098+
try:
1099+
with patch("aws_durable_execution_sdk_python.serdes.serialize") as mock_serialize:
1100+
mock_serialize.return_value = '"serialized"'
1101+
importlib.reload(child)
1102+
1103+
parent_checkpoint = Mock()
1104+
parent_checkpoint.is_succeeded.return_value = False
1105+
parent_checkpoint.is_failed.return_value = False
1106+
parent_checkpoint.is_existent.return_value = False
1107+
parent_checkpoint.is_replay_children.return_value = False
1108+
1109+
child_checkpoint = Mock()
1110+
child_checkpoint.is_succeeded.return_value = False
1111+
child_checkpoint.is_failed.return_value = False
1112+
child_checkpoint.is_existent.return_value = False
1113+
child_checkpoint.is_replay_children.return_value = False
1114+
1115+
def get_checkpoint(op_id):
1116+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
1117+
1118+
mock_state = Mock()
1119+
mock_state.durable_execution_arn = "arn:test"
1120+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
1121+
mock_state.create_checkpoint = Mock()
1122+
1123+
context_map = {}
1124+
1125+
def create_id(self, i):
1126+
ctx_id = id(self)
1127+
if ctx_id not in context_map:
1128+
context_map[ctx_id] = []
1129+
context_map[ctx_id].append(i)
1130+
return (
1131+
"parent"
1132+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
1133+
else f"child-{i}"
1134+
)
10961135

1097-
parent_checkpoint = Mock()
1098-
parent_checkpoint.is_succeeded.return_value = False
1099-
parent_checkpoint.is_failed.return_value = False
1100-
parent_checkpoint.is_existent.return_value = False
1101-
parent_checkpoint.is_replay_children.return_value = False
1102-
1103-
child_checkpoint = Mock()
1104-
child_checkpoint.is_succeeded.return_value = False
1105-
child_checkpoint.is_failed.return_value = False
1106-
child_checkpoint.is_existent.return_value = False
1107-
child_checkpoint.is_replay_children.return_value = False
1108-
1109-
def get_checkpoint(op_id):
1110-
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
1111-
1112-
mock_state = Mock()
1113-
mock_state.durable_execution_arn = "arn:test"
1114-
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
1115-
mock_state.create_checkpoint = Mock()
1116-
1117-
context_map = {}
1118-
1119-
def create_id(self, i):
1120-
ctx_id = id(self)
1121-
if ctx_id not in context_map:
1122-
context_map[ctx_id] = []
1123-
context_map[ctx_id].append(i)
1124-
return (
1125-
"parent"
1126-
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
1127-
else f"child-{i}"
1128-
)
1129-
1130-
with patch.object(
1131-
DurableContext, "_create_step_id_for_logical_step", create_id
1132-
):
1133-
context = create_test_context(state=mock_state)
1134-
result = context.map(
1135-
["a", "b"],
1136-
lambda ctx, item, idx, items: item,
1137-
config=MapConfig(serdes=custom_serdes),
1138-
)
1139-
1140-
assert isinstance(result, BatchResult)
1141-
assert len(mock_serialize.call_args_list) == 3
1142-
parent_call = mock_serialize.call_args_list[2]
1143-
assert parent_call[1]["serdes"] is custom_serdes
1144-
assert isinstance(parent_call[1]["value"], BatchResult)
1136+
with patch.object(
1137+
DurableContext, "_create_step_id_for_logical_step", create_id
1138+
):
1139+
context = create_test_context(state=mock_state)
1140+
result = context.map(
1141+
["a", "b"],
1142+
lambda ctx, item, idx, items: item,
1143+
config=MapConfig(serdes=custom_serdes),
1144+
)
1145+
1146+
assert isinstance(result, BatchResult)
1147+
assert len(mock_serialize.call_args_list) == 3
1148+
parent_call = mock_serialize.call_args_list[2]
1149+
assert parent_call[1]["serdes"] is custom_serdes
1150+
assert isinstance(parent_call[1]["value"], BatchResult)
1151+
finally:
1152+
importlib.reload(child)
11451153
assert parent_call[1]["value"] is result
11461154

11471155

0 commit comments

Comments
 (0)