22// SPDX-License-Identifier: Apache-2.0
33package com .amazonaws .lambda .durable .execution ;
44
5+ import com .amazonaws .lambda .durable .DurableConfig ;
56import com .amazonaws .lambda .durable .client .DurableExecutionClient ;
7+ import java .time .Duration ;
68import java .util .ArrayList ;
9+ import java .util .Collections ;
710import java .util .List ;
11+ import java .util .Map ;
812import java .util .Objects ;
913import java .util .concurrent .BlockingQueue ;
1014import java .util .concurrent .CompletableFuture ;
15+ import java .util .concurrent .ConcurrentHashMap ;
1116import java .util .concurrent .LinkedBlockingQueue ;
1217import java .util .concurrent .atomic .AtomicBoolean ;
1318import java .util .function .Consumer ;
@@ -35,41 +40,57 @@ class CheckpointBatcher {
3540 private final String durableExecutionArn ;
3641 private final DurableExecutionClient client ;
3742 private final BlockingQueue <CheckpointRequest > queue = new LinkedBlockingQueue <>();
38- private final AtomicBoolean isProcessing = new AtomicBoolean (false );
43+ private final AtomicBoolean isRunning = new AtomicBoolean (true );
44+ private final Duration pollingInterval ;
45+ private final Map <String , List <CompletableFuture <Operation >>> pollingFutures = new ConcurrentHashMap <>();
3946 private String checkpointToken ;
4047
4148 record CheckpointRequest (OperationUpdate update , CompletableFuture <Void > completion ) {}
4249
4350 CheckpointBatcher (
44- DurableExecutionClient client ,
51+ DurableConfig config ,
4552 String durableExecutionArn ,
4653 String checkpointToken ,
4754 Consumer <List <Operation >> callback ) {
48- this .client = client ;
55+ this .client = config . getDurableExecutionClient () ;
4956 this .durableExecutionArn = durableExecutionArn ;
5057 this .callback = callback ;
5158 this .checkpointToken = checkpointToken ;
59+ this .pollingInterval = config .getPollingInterval ();
60+
61+ InternalExecutor .INSTANCE .execute (this ::processQueue );
5262 }
5363
5464 CompletableFuture <Void > checkpoint (OperationUpdate update ) {
55- logger .debug (
56- "Checkpoint request received: Action {}" ,
57- update != null ? update .action () : "NULL (Checkpoint request)" );
65+ logger .debug ("Checkpoint request received: Action {}" , update .action ());
5866 var future = new CompletableFuture <Void >();
5967 queue .add (new CheckpointRequest (update , future ));
68+ return future ;
69+ }
6070
61- if (isProcessing .compareAndSet (false , true )) {
62- InternalExecutor .INSTANCE .execute (this ::processQueue );
63- }
64-
71+ CompletableFuture <Operation > pollForUpdate (String operationId ) {
72+ var future = new CompletableFuture <Operation >();
73+ pollingFutures
74+ .computeIfAbsent (operationId , k -> Collections .synchronizedList (new ArrayList <>()))
75+ .add (future );
6576 return future ;
6677 }
6778
6879 void shutdown () {
6980 var remaining = new ArrayList <CheckpointRequest >();
81+ isRunning .set (false );
7082 queue .drainTo (remaining );
83+
84+ // fail the checkpoint requests
7185 remaining .forEach (
7286 req -> req .completion ().completeExceptionally (new IllegalStateException ("CheckpointManager shutdown" )));
87+
88+ // fail the pollers
89+ for (var operationId : pollingFutures .keySet ()) {
90+ pollingFutures
91+ .remove (operationId )
92+ .forEach (f -> f .completeExceptionally (new IllegalStateException ("CheckpointManager shutdown" )));
93+ }
7394 }
7495
7596 public List <Operation > fetchAllPages (List <Operation > initialOperations , String nextMarker ) {
@@ -79,25 +100,32 @@ public List<Operation> fetchAllPages(List<Operation> initialOperations, String n
79100 }
80101 while (nextMarker != null && !nextMarker .isEmpty ()) {
81102 var response = client .getExecutionState (durableExecutionArn , checkpointToken , nextMarker );
82- logger .debug ("DAR getExecutionState called: {}." , response );
103+ logger .debug ("Durable API getExecutionState called: {}." , response );
83104 operations .addAll (response .operations ());
84105 nextMarker = response .nextMarker ();
85106 }
86107 return operations ;
87108 }
88109
89110 private void processQueue () {
90- try {
91- var batch = collectBatch ();
92- if (!batch .isEmpty ()) {
111+ while (isRunning .get ()) {
112+ if (queue .isEmpty () && pollingFutures .isEmpty ()) {
113+ // nothing to process
114+ try {
115+ Thread .sleep (pollingInterval .toMillis ());
116+ } catch (InterruptedException ignored ) {
117+ }
118+ }
119+ try {
120+ var batch = collectBatch ();
93121 // Filter out null updates (empty checkpoints for polling)
94122 var updates = batch .stream ()
95123 .map (CheckpointRequest ::update )
96124 .filter (Objects ::nonNull )
97125 .toList ();
98126
99127 var response = client .checkpoint (durableExecutionArn , checkpointToken , updates );
100- logger .debug ("DAR checkpointDurableExecution called: {}." , response );
128+ logger .debug ("Durable API checkpointDurableExecution called: {}." , response );
101129
102130 // Notify callback of completion
103131 // TODO: sam local backend returns no new execution state when called with zero
@@ -112,19 +140,21 @@ private void processQueue() {
112140 if (!operations .isEmpty ()) {
113141 callback .accept (operations );
114142 }
143+
144+ for (var operation : operations ) {
145+ var pollers = pollingFutures .remove (operation .id ());
146+ if (pollers != null ) {
147+ pollers .forEach (poller -> poller .complete (operation ));
148+ }
149+ }
115150 }
116151
152+ // checkpoint operation completed
117153 batch .forEach (req -> req .completion ().complete (null ));
118- }
119- } catch (Exception e ) {
120- var batch = new ArrayList <CheckpointRequest >();
121- queue .drainTo (batch );
122- batch .forEach (req -> req .completion ().completeExceptionally (e ));
123- } finally {
124- isProcessing .set (false );
125-
126- if (!queue .isEmpty () && isProcessing .compareAndSet (false , true )) {
127- InternalExecutor .INSTANCE .execute (this ::processQueue );
154+ } catch (Throwable e ) {
155+ var batch = new ArrayList <CheckpointRequest >();
156+ queue .drainTo (batch );
157+ batch .forEach (req -> req .completion ().completeExceptionally (e ));
128158 }
129159 }
130160 }
@@ -134,14 +164,16 @@ private List<CheckpointRequest> collectBatch() {
134164 var currentSize = 0 ;
135165
136166 CheckpointRequest req ;
137- while ((req = queue .poll ()) != null ) {
167+ while ((req = queue .peek ()) != null ) {
138168 var itemSize = estimateSize (req .update ());
139169
170+ // will include the big item in the batch if the batch is empty
140171 if (currentSize + itemSize > MAX_BATCH_SIZE_BYTES && !batch .isEmpty ()) {
141- queue .add (req );
142172 break ;
143173 }
144174
175+ queue .remove ();
176+
145177 batch .add (req );
146178 currentSize += itemSize ;
147179 }
0 commit comments