Skip to content

Commit 12b5d66

Browse files
committed
fix: lock operations access in ExecutionState
- Make operations private (_operations); add a read-only snapshot property to preserve the public attribute - Read operations under _operations_lock in track_replay and get_execution_operation, closing a dictionary-changed-size race against the concurrent checkpoint update path - Add regression test for concurrent track_replay and update
1 parent c577675 commit 12b5d66

2 files changed

Lines changed: 116 additions & 19 deletions

File tree

packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def __init__(
250250
):
251251
self.durable_execution_arn: str = durable_execution_arn
252252
self._current_checkpoint_token: str = initial_checkpoint_token
253-
self.operations: MutableMapping[str, Operation] = operations
253+
self._operations: dict[str, Operation] = dict(operations)
254254
self._service_client: DurableServiceClient = service_client
255255
self._plugin_executor: PluginExecutor = plugin_executor
256256
self._ordered_checkpoint_lock: OrderedLock = OrderedLock()
@@ -279,6 +279,16 @@ def __init__(
279279
self._replay_status_lock: Lock = Lock()
280280
self._visited_operations: set[str] = set()
281281

282+
@property
283+
def operations(self) -> dict[str, Operation]:
284+
"""Return a point-in-time snapshot copy of the operations map.
285+
286+
The returned dict is a copy, so mutating it does not affect execution
287+
state and iterating it is safe against concurrent updates.
288+
"""
289+
with self._operations_lock:
290+
return dict(self._operations)
291+
282292
def fetch_paginated_operations(
283293
self,
284294
initial_operations: list[Operation],
@@ -324,7 +334,7 @@ def fetch_paginated_operations(
324334
# Always store whatever operations we successfully fetched
325335
if all_operations:
326336
with self._operations_lock:
327-
self.operations.update(
337+
self._operations.update(
328338
{op.operation_id: op for op in all_operations}
329339
)
330340
return all_operations
@@ -341,7 +351,8 @@ def get_input_payload(self) -> str | None:
341351
def get_execution_operation(self) -> Operation | None:
342352
# invocation id is id of execution operation
343353
invocation_id = self.durable_execution_arn.split("/")[-1]
344-
candidate = self.operations.get(invocation_id)
354+
with self._operations_lock:
355+
candidate = self._operations.get(invocation_id)
345356
if not candidate:
346357
# Due to payload size limitations we may have an empty operations list.
347358
# This will only happen when loading the initial page of results and is
@@ -370,19 +381,21 @@ def track_replay(self, operation_id: str) -> None:
370381
with self._replay_status_lock:
371382
if self._replay_status == ReplayStatus.REPLAY:
372383
self._visited_operations.add(operation_id)
373-
completed_ops = {
374-
op_id
375-
for op_id, op in self.operations.items()
376-
if op.operation_type != OperationType.EXECUTION
377-
and op.status
378-
in {
379-
OperationStatus.SUCCEEDED,
380-
OperationStatus.FAILED,
381-
OperationStatus.CANCELLED,
382-
OperationStatus.STOPPED,
383-
OperationStatus.TIMED_OUT,
384+
# Lock order: _replay_status_lock then _operations_lock.
385+
with self._operations_lock:
386+
completed_ops = {
387+
op_id
388+
for op_id, op in self._operations.items()
389+
if op.operation_type != OperationType.EXECUTION
390+
and op.status
391+
in {
392+
OperationStatus.SUCCEEDED,
393+
OperationStatus.FAILED,
394+
OperationStatus.CANCELLED,
395+
OperationStatus.STOPPED,
396+
OperationStatus.TIMED_OUT,
397+
}
384398
}
385-
}
386399
if completed_ops.issubset(self._visited_operations):
387400
logger.debug(
388401
"Transitioning from REPLAY to NEW status at operation %s",
@@ -404,7 +417,7 @@ def mark_replaying_if_prior_operations_exist(self) -> None:
404417
with self._operations_lock:
405418
has_prior_operations: bool = any(
406419
op.operation_type is not OperationType.EXECUTION
407-
for op in self.operations.values()
420+
for op in self._operations.values()
408421
)
409422

410423
if has_prior_operations:
@@ -431,7 +444,7 @@ def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult:
431444
"""
432445
# checking status are deliberately under a lighter non-serialized lock
433446
with self._operations_lock:
434-
if checkpoint := self.operations.get(checkpoint_id):
447+
if checkpoint := self._operations.get(checkpoint_id):
435448
return CheckpointedResult.create_from_operation(checkpoint)
436449

437450
return CHECKPOINT_NOT_FOUND

packages/aws-durable-execution-sdk-python/tests/state_test.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ def test_concurrent_access_to_operations_dictionary():
13971397
operation_type=OperationType.STEP,
13981398
status=OperationStatus.SUCCEEDED,
13991399
)
1400-
state.operations["op1"] = operation
1400+
state._operations["op1"] = operation
14011401

14021402
results = []
14031403
errors = []
@@ -1422,7 +1422,7 @@ def writer_thread():
14221422
status=OperationStatus.SUCCEEDED,
14231423
)
14241424
with state._operations_lock:
1425-
state.operations[f"op{i}"] = new_op
1425+
state._operations[f"op{i}"] = new_op
14261426
time.sleep(0.001)
14271427
except Exception as e:
14281428
errors.append(e)
@@ -4260,3 +4260,87 @@ def test_plugin_executor_not_called_for_pending_operations():
42604260

42614261

42624262
# endregion Plugin Executor Integration Tests
4263+
4264+
4265+
def _make_execution_state_for_operations(
4266+
mock_lambda_client, *, replay_status=ReplayStatus.NEW, operations=None
4267+
):
4268+
return ExecutionState(
4269+
durable_execution_arn="test_arn",
4270+
initial_checkpoint_token="token123", # noqa: S106
4271+
operations=operations or {},
4272+
service_client=mock_lambda_client,
4273+
plugin_executor=PluginExecutor(plugins=None),
4274+
replay_status=replay_status,
4275+
)
4276+
4277+
4278+
def test_operations_property_returns_snapshot_copy():
4279+
"""The operations property exposes a copy; mutating it must not affect state."""
4280+
mock_lambda_client = Mock(spec=LambdaClient)
4281+
op = Operation(
4282+
operation_id="op1",
4283+
operation_type=OperationType.STEP,
4284+
status=OperationStatus.SUCCEEDED,
4285+
)
4286+
state = _make_execution_state_for_operations(
4287+
mock_lambda_client, operations={"op1": op}
4288+
)
4289+
4290+
snapshot = state.operations
4291+
assert snapshot == {"op1": op}
4292+
4293+
snapshot["op2"] = op # mutating the returned copy must not leak into state
4294+
assert "op2" not in state.operations
4295+
assert len(state.operations) == 1
4296+
4297+
4298+
def test_track_replay_iteration_safe_under_concurrent_update():
4299+
"""track_replay must not raise when operations are updated concurrently.
4300+
4301+
A worker thread iterates operations inside track_replay while the checkpoint
4302+
path updates the same map. Without consistent locking this raises
4303+
"dictionary changed size during iteration".
4304+
"""
4305+
mock_lambda_client = Mock(spec=LambdaClient)
4306+
state = _make_execution_state_for_operations(
4307+
mock_lambda_client, replay_status=ReplayStatus.REPLAY
4308+
)
4309+
# Seed completed operations so track_replay keeps iterating (stays REPLAY).
4310+
for i in range(50):
4311+
state._operations[f"seed{i}"] = Operation(
4312+
operation_id=f"seed{i}",
4313+
operation_type=OperationType.STEP,
4314+
status=OperationStatus.SUCCEEDED,
4315+
)
4316+
4317+
errors: list[Exception] = []
4318+
stop = threading.Event()
4319+
4320+
def writer():
4321+
i = 0
4322+
while not stop.is_set():
4323+
with state._operations_lock:
4324+
state._operations[f"w{i}"] = Operation(
4325+
operation_id=f"w{i}",
4326+
operation_type=OperationType.STEP,
4327+
status=OperationStatus.SUCCEEDED,
4328+
)
4329+
i += 1
4330+
4331+
def reader():
4332+
try:
4333+
for _ in range(2000):
4334+
state.track_replay(operation_id="probe")
4335+
except Exception as e: # noqa: BLE001
4336+
errors.append(e)
4337+
4338+
writer_t = threading.Thread(target=writer, daemon=True)
4339+
reader_t = threading.Thread(target=reader, daemon=True)
4340+
writer_t.start()
4341+
reader_t.start()
4342+
reader_t.join(timeout=30)
4343+
stop.set()
4344+
writer_t.join(timeout=5)
4345+
4346+
assert not errors, f"track_replay raced with concurrent update: {errors}"

0 commit comments

Comments
 (0)