Skip to content

Commit c38ae3c

Browse files
committed
perf: coalesce empty checkpoints in batch collector - Empty checkpoints (used by map/parallel branch resubmitters when resuming from timed waits) no longer count toward the 250-operation batch limit beyond the first. This prevents 300+ concurrent branch resumes from splitting across multiple API batches.
1 parent 79fcb95 commit c38ae3c

3 files changed

Lines changed: 483 additions & 30 deletions

File tree

src/aws_durable_execution_sdk_python/state.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -592,15 +592,21 @@ def checkpoint_batches_forever(self) -> None:
592592
batch: list[QueuedOperation] = self._collect_checkpoint_batch()
593593

594594
if batch:
595-
# Extract OperationUpdates from QueuedOperations for API call
596-
updates: list[OperationUpdate] = [
597-
q.operation_update for q in batch if q.operation_update is not None
598-
]
595+
# Extract OperationUpdates, excluding empty checkpoints from API call
596+
updates: list[OperationUpdate] = []
597+
empty_count = 0
598+
599+
for q in batch:
600+
if q.operation_update is not None:
601+
updates.append(q.operation_update)
602+
else:
603+
empty_count += 1
599604

600605
logger.debug(
601-
"Processing checkpoint batch with %d operations (%d non-empty)",
602-
len(batch),
606+
"Sending %d OperationUpdates out of %d operations, excluding %d empty checkpoints",
603607
len(updates),
608+
len(batch),
609+
empty_count,
604610
)
605611

606612
try:
@@ -687,26 +693,41 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
687693
operation if queues are empty, then collects additional operations within the time
688694
window.
689695
696+
Empty checkpoints (operation_update=None) are coalesced: the first empty checkpoint
697+
counts toward the batch operation limit, but subsequent empty checkpoints do not.
698+
All empty checkpoints remain in the batch so their completion events are signaled.
699+
This avoids unnecessary batches when many concurrent map/parallel branches resume
700+
simultaneously and each queues an empty checkpoint.
701+
690702
Returns:
691703
List of QueuedOperation objects ready for batch processing. Returns empty list
692704
if no operations are available.
693705
"""
694706
batch: list[QueuedOperation] = []
707+
has_empty_checkpoint = False
695708
total_size = 0
709+
effective_operation_count = 0 # Operations that count toward batch limit
696710

697711
# First, drain overflow queue (FIFO order preserved)
698712
try:
699-
while len(batch) < self._batcher_config.max_batch_operations:
713+
while effective_operation_count < self._batcher_config.max_batch_operations:
700714
overflow_op = self._overflow_queue.get_nowait()
701-
op_size = self._calculate_operation_size(overflow_op)
702-
703-
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
704-
# Put back and stop
705-
self._overflow_queue.put(overflow_op)
706-
break
707715

708-
batch.append(overflow_op)
709-
total_size += op_size
716+
if overflow_op.operation_update is None: # Empty checkpoint
717+
batch.append(overflow_op)
718+
if not has_empty_checkpoint:
719+
effective_operation_count += 1 # First empty counts toward limit
720+
has_empty_checkpoint = True
721+
# Subsequent empties don't count toward limit
722+
else:
723+
op_size = self._calculate_operation_size(overflow_op)
724+
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
725+
# Put back and stop
726+
self._overflow_queue.put(overflow_op)
727+
break
728+
batch.append(overflow_op)
729+
total_size += op_size
730+
effective_operation_count += 1
710731
except queue.Empty:
711732
pass
712733

@@ -720,7 +741,13 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
720741
) # Check stop signal every 100ms
721742
self._checkpoint_queue.task_done()
722743
batch.append(first_op)
723-
total_size += self._calculate_operation_size(first_op)
744+
745+
if first_op.operation_update is None:
746+
has_empty_checkpoint = True
747+
else:
748+
total_size += self._calculate_operation_size(first_op)
749+
750+
effective_operation_count = 1
724751
break
725752
except queue.Empty:
726753
continue
@@ -735,7 +762,7 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
735762
# Collect additional operations within the time window
736763
while (
737764
time.time() < batch_deadline
738-
and len(batch) < self._batcher_config.max_batch_operations
765+
and effective_operation_count < self._batcher_config.max_batch_operations
739766
and not self._checkpointing_stopped.is_set()
740767
):
741768
remaining_time = min(
@@ -749,26 +776,37 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
749776
try:
750777
additional_op = self._checkpoint_queue.get(timeout=remaining_time)
751778
self._checkpoint_queue.task_done()
752-
op_size = self._calculate_operation_size(additional_op)
753-
754-
# Check if adding this operation would exceed size limit
755-
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
756-
# Put in overflow queue for next batch
757-
self._overflow_queue.put(additional_op)
758-
logger.debug(
759-
"Batch size limit reached, moving operation to overflow queue"
760-
)
761-
break
762779

763-
batch.append(additional_op)
764-
total_size += op_size
780+
if additional_op.operation_update is None: # Empty checkpoint
781+
batch.append(additional_op)
782+
if not has_empty_checkpoint:
783+
effective_operation_count += 1 # First empty counts toward limit
784+
has_empty_checkpoint = True
785+
# Subsequent empties don't count toward limit
786+
else:
787+
op_size = self._calculate_operation_size(additional_op)
788+
# Check if adding this operation would exceed size limit
789+
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
790+
# Put in overflow queue for next batch
791+
self._overflow_queue.put(additional_op)
792+
logger.debug(
793+
"Batch size limit reached, moving operation to overflow queue"
794+
)
795+
break
796+
batch.append(additional_op)
797+
total_size += op_size
798+
effective_operation_count += 1
765799

766800
except queue.Empty:
767801
break
768802

803+
empty_count = sum(1 for q in batch if q.operation_update is None)
769804
logger.debug(
770-
"Collected batch of %d operations, total size: %d bytes",
805+
"Collected batch of %d operations (%d effective, %d non-empty, %d empty), total size: %d bytes",
771806
len(batch),
807+
effective_operation_count,
808+
len(batch) - empty_count,
809+
empty_count,
772810
total_size,
773811
)
774812
return batch
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""Integration test: empty checkpoint coalescing with concurrent map + wait.
2+
3+
Python equivalent of the Java MapWithConditionAndCallbackExample referenced in
4+
issue #325. Verifies that when many concurrent map branches resume from timed
5+
wait operations simultaneously, the empty checkpoints produced by the
6+
resubmitter (executor.py) are coalesced into minimal API calls instead of
7+
being split across multiple batches.
8+
9+
Background
10+
----------
11+
When a map branch suspends via TimedSuspendExecution and later resumes, the
12+
ConcurrentExecutor resubmitter calls::
13+
14+
execution_state.create_checkpoint() # empty checkpoint
15+
16+
before resubmitting the branch. In high-concurrency scenarios (300+ branches)
17+
all resuming at the same time, 300+ empty checkpoints flood the checkpoint
18+
queue.
19+
20+
Without the coalescing optimization (issue #325), the 250-operation batch limit
21+
causes these to be split across multiple batches → multiple API calls.
22+
With the optimization, all subsequent empty checkpoints beyond the first do
23+
NOT count toward the batch limit, so they are coalesced into a single batch
24+
and a single API call.
25+
26+
These tests directly simulate that concurrent-checkpoint pattern by launching
27+
many threads that each call ``create_checkpoint()`` simultaneously, mirroring
28+
what the map resubmitter does when all branches resume at once.
29+
"""
30+
31+
from __future__ import annotations
32+
33+
import threading
34+
from concurrent.futures import ThreadPoolExecutor
35+
36+
import pytest
37+
38+
from aws_durable_execution_sdk_python.lambda_service import (
39+
CheckpointOutput,
40+
CheckpointUpdatedExecutionState,
41+
LambdaClient,
42+
OperationAction,
43+
OperationUpdate,
44+
OperationType,
45+
)
46+
from aws_durable_execution_sdk_python.state import (
47+
CheckpointBatcherConfig,
48+
ExecutionState,
49+
QueuedOperation,
50+
)
51+
from aws_durable_execution_sdk_python.threading import CompletionEvent
52+
53+
from unittest.mock import Mock
54+
55+
56+
def _make_state(
57+
mock_client: Mock,
58+
batch_time: float = 5.0,
59+
max_ops: int = 250,
60+
) -> ExecutionState:
61+
config = CheckpointBatcherConfig(
62+
max_batch_size_bytes=10 * 1024 * 1024,
63+
max_batch_time_seconds=batch_time,
64+
max_batch_operations=max_ops,
65+
)
66+
return ExecutionState(
67+
durable_execution_arn="test-arn",
68+
initial_checkpoint_token="token-0", # noqa: S106
69+
operations={},
70+
service_client=mock_client,
71+
batcher_config=config,
72+
)
73+
74+
75+
def _make_tracking_client() -> tuple[Mock, list]:
76+
"""Return a (mock LambdaClient, checkpoint_calls list) pair."""
77+
calls: list[list] = []
78+
mock_client = Mock(spec=LambdaClient)
79+
80+
def _checkpoint(durable_execution_arn, checkpoint_token, updates, client_token=None):
81+
calls.append(list(updates))
82+
return CheckpointOutput(
83+
checkpoint_token=f"token_{len(calls)}",
84+
new_execution_state=CheckpointUpdatedExecutionState(),
85+
)
86+
87+
mock_client.checkpoint = _checkpoint
88+
return mock_client, calls
89+
90+
91+
def test_map_with_concurrent_waits_coalesces_empty_checkpoints():
92+
"""300 concurrent branches all create empty checkpoints simultaneously.
93+
94+
Simulates the Java MapWithConditionAndCallbackExample scenario: 300 map
95+
branches all resuming from a wait operation at the same time, each calling
96+
the resubmitter which enqueues an empty checkpoint.
97+
98+
Without the coalescing optimization, the 250-op batch limit splits 300
99+
empty checkpoints into 2 batches (250 + 50) → 2 API calls.
100+
With the optimization (effective_operation_count stays 1 for empties),
101+
all 300 are collected in a single batch → 1 API call.
102+
"""
103+
mock_client, calls = _make_tracking_client()
104+
state = _make_state(mock_client, batch_time=5.0, max_ops=250)
105+
106+
batcher = ThreadPoolExecutor(max_workers=1)
107+
batcher.submit(state.checkpoint_batches_forever)
108+
109+
# 300 branches all call create_checkpoint() concurrently, each blocking
110+
# until the batch is processed — mirrors the resubmitter pattern.
111+
branch_count = 300
112+
start_barrier = threading.Barrier(branch_count)
113+
errors: list[Exception] = []
114+
115+
def branch_work():
116+
try:
117+
start_barrier.wait() # all start simultaneously
118+
state.create_checkpoint() # empty checkpoint, synchronous
119+
except Exception as e: # noqa: BLE001
120+
errors.append(e)
121+
122+
threads = [threading.Thread(target=branch_work) for _ in range(branch_count)]
123+
for t in threads:
124+
t.start()
125+
for t in threads:
126+
t.join(timeout=30)
127+
128+
try:
129+
assert not errors, f"Branch errors: {errors}"
130+
131+
# All 300 empty checkpoints should be batched into 1 API call.
132+
# Without the fix, 300 > 250 limit would produce 2 calls.
133+
assert len(calls) == 1, (
134+
f"Expected 1 coalesced API call for {branch_count} concurrent empty "
135+
f"checkpoints, got {len(calls)}. The 250-op limit must not split empties."
136+
)
137+
assert calls[0] == [], "Empty checkpoints should produce an empty updates list"
138+
finally:
139+
state.stop_checkpointing()
140+
batcher.shutdown(wait=True)
141+
142+
143+
def test_map_with_concurrent_waits_api_call_count_scales_with_real_ops_not_empties():
144+
"""400 empty checkpoints + 10 real ops → 1 API call with limit=11.
145+
146+
Demonstrates that the effective batch count is driven by real operations
147+
(and only the *first* empty), not the total number of empties.
148+
149+
With limit=11: the first empty counts as effective_op 1, and each of the
150+
10 real ops increments the count (effective_ops 2–11). The limit is hit
151+
exactly when the last real op is collected. All 399 remaining empties are
152+
coalesced in without incrementing the count.
153+
154+
Result: 1 batch (410 operations, 10 real) → 1 API call.
155+
"""
156+
mock_client, calls = _make_tracking_client()
157+
# limit = 1 (first empty) + 10 (real ops) = 11, so all fit in one batch
158+
state = _make_state(mock_client, batch_time=5.0, max_ops=11)
159+
160+
batcher = ThreadPoolExecutor(max_workers=1)
161+
batcher.submit(state.checkpoint_batches_forever)
162+
163+
completion_events: list[CompletionEvent] = []
164+
165+
try:
166+
# 400 empty checkpoints (simulating concurrent branch resumes)
167+
for _ in range(400):
168+
ev = CompletionEvent()
169+
completion_events.append(ev)
170+
state._checkpoint_queue.put(QueuedOperation(None, ev)) # noqa: SLF001
171+
172+
# 10 real operations alongside the empties
173+
for i in range(10):
174+
op = OperationUpdate(
175+
operation_id=f"op_{i}",
176+
operation_type=OperationType.STEP,
177+
action=OperationAction.START,
178+
)
179+
ev = CompletionEvent()
180+
completion_events.append(ev)
181+
state._checkpoint_queue.put(QueuedOperation(op, ev)) # noqa: SLF001
182+
183+
# Wait for all 410 to be processed
184+
for ev in completion_events:
185+
ev.wait()
186+
187+
# 1 empty (effective=1) + 10 real ops (effective=11) exhaust the batch
188+
# limit exactly. The 399 remaining empties coalesce in → still 1 API call.
189+
assert len(calls) == 1, (
190+
f"Expected 1 API call with 400 empty + 10 real ops (limit=11), "
191+
f"got {len(calls)}."
192+
)
193+
# Only the 10 real ops appear in the updates list; empties are excluded.
194+
real_op_ids = {u.operation_id for batch in calls for u in batch}
195+
assert real_op_ids == {f"op_{i}" for i in range(10)}
196+
finally:
197+
state.stop_checkpointing()
198+
batcher.shutdown(wait=True)

0 commit comments

Comments
 (0)