Skip to content

Commit 48ff006

Browse files
committed
move batcher
Co-authored-by: thomas <18520168+yaythomas@users.noreply.github.com>
1 parent b4db34b commit 48ff006

3 files changed

Lines changed: 25 additions & 15 deletions

File tree

src/aws_durable_execution_sdk_python/state.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,9 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
716716
if overflow_op.operation_update is None: # Empty checkpoint
717717
batch.append(overflow_op)
718718
if not has_empty_checkpoint:
719-
effective_operation_count += 1 # First empty counts toward limit
719+
effective_operation_count += (
720+
1 # First empty counts toward limit
721+
)
720722
has_empty_checkpoint = True
721723
# Subsequent empties don't count toward limit
722724
else:
@@ -780,7 +782,9 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
780782
if additional_op.operation_update is None: # Empty checkpoint
781783
batch.append(additional_op)
782784
if not has_empty_checkpoint:
783-
effective_operation_count += 1 # First empty counts toward limit
785+
effective_operation_count += (
786+
1 # First empty counts toward limit
787+
)
784788
has_empty_checkpoint = True
785789
# Subsequent empties don't count toward limit
786790
else:

tests/e2e/map_with_concurrent_waits_int_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def _make_tracking_client() -> tuple[Mock, list]:
7676
calls: list[list] = []
7777
mock_client = Mock(spec=LambdaClient)
7878

79-
def _checkpoint(durable_execution_arn, checkpoint_token, updates, client_token=None):
79+
def _checkpoint(
80+
durable_execution_arn, checkpoint_token, updates, client_token=None
81+
):
8082
calls.append(list(updates))
8183
return CheckpointOutput(
8284
checkpoint_token=f"token_{len(calls)}",
@@ -113,9 +115,9 @@ def test_map_with_concurrent_waits_coalesces_empty_checkpoints():
113115

114116
def branch_work():
115117
try:
116-
start_barrier.wait() # all start simultaneously
117-
state.create_checkpoint() # empty checkpoint, synchronous
118-
except Exception as e: # noqa: BLE001
118+
start_barrier.wait() # all start simultaneously
119+
state.create_checkpoint() # empty checkpoint, synchronous
120+
except Exception as e: # noqa: BLE001
119121
errors.append(e)
120122

121123
threads = [threading.Thread(target=branch_work) for _ in range(branch_count)]
@@ -156,7 +158,6 @@ def test_map_with_concurrent_waits_api_call_count_scales_with_real_ops_not_empti
156158
# limit = 1 (first empty) + 10 (real ops) = 11, so all fit in one batch
157159
state = _make_state(mock_client, batch_time=5.0, max_ops=11)
158160

159-
160161
completion_events: list[CompletionEvent] = []
161162

162163
try:
@@ -173,13 +174,14 @@ def test_map_with_concurrent_waits_api_call_count_scales_with_real_ops_not_empti
173174
operation_type=OperationType.STEP,
174175
action=OperationAction.START,
175176
)
176-
batcher = ThreadPoolExecutor(max_workers=1)
177-
batcher.submit(state.checkpoint_batches_forever)
178-
177+
179178
ev = CompletionEvent()
180179
completion_events.append(ev)
181180
state._checkpoint_queue.put(QueuedOperation(op, ev)) # noqa: SLF001
182181

182+
batcher = ThreadPoolExecutor(max_workers=1)
183+
batcher.submit(state.checkpoint_batches_forever)
184+
183185
# Wait for all 410 to be processed
184186
for ev in completion_events:
185187
ev.wait()

tests/state_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3419,7 +3419,7 @@ def test_collect_checkpoint_batch_empty_checkpoints_with_real_ops_respects_limit
34193419
real_in_batch = sum(1 for q in batch if q.operation_update is not None)
34203420

34213421
assert empty_in_batch == 3 # All empty checkpoints coalesced
3422-
assert real_in_batch == 4 # 4 real ops (1 slot used by the first empty)
3422+
assert real_in_batch == 4 # 4 real ops (1 slot used by the first empty)
34233423

34243424

34253425
def test_collect_checkpoint_batch_overflow_coalesces_empty_checkpoints():
@@ -3540,9 +3540,11 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit():
35403540
operation_type=OperationType.STEP,
35413541
action=OperationAction.START,
35423542
)
3543-
state._checkpoint_queue.put(QueuedOperation(None, None)) # empty — effective=1
3544-
state._checkpoint_queue.put(QueuedOperation(op1, None)) # real — effective=2, limit hit
3545-
state._checkpoint_queue.put(QueuedOperation(op2, None)) # real — stays in queue
3543+
state._checkpoint_queue.put(QueuedOperation(None, None)) # empty — effective=1
3544+
state._checkpoint_queue.put(
3545+
QueuedOperation(op1, None)
3546+
) # real — effective=2, limit hit
3547+
state._checkpoint_queue.put(QueuedOperation(op2, None)) # real — stays in queue
35463548

35473549
for _ in range(50):
35483550
state._checkpoint_queue.put(QueuedOperation(None, None)) # trailing empties
@@ -3555,6 +3557,8 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit():
35553557
# The batch contains exactly: 1 leading empty + op_1 (limit=2 effective ops)
35563558
assert len(real_in_batch) == 1
35573559
assert real_in_batch[0].operation_update.operation_id == "op_1"
3558-
assert len(empty_in_batch) == 1 # Only the leading empty; trailing deferred to next batch
3560+
assert (
3561+
len(empty_in_batch) == 1
3562+
) # Only the leading empty; trailing deferred to next batch
35593563
# op_2 and trailing empties remain in the queue
35603564
assert state._checkpoint_queue.qsize() == 51

0 commit comments

Comments
 (0)