Skip to content

Commit e0018d3

Browse files
zhongkechennvasiu
authored andcommitted
fix polling for updates (aws#141)
1 parent 428670c commit e0018d3

2 files changed

Lines changed: 310 additions & 3 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/execution/CheckpointBatcher.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,17 @@ CompletableFuture<Operation> pollForUpdate(String operationId, Duration delay) {
7070
.computeIfAbsent(operationId, k -> Collections.synchronizedList(new ArrayList<>()))
7171
.add(future);
7272
}
73-
checkpointApiRequestBatcher.submit(null, delay).thenCompose(v -> {
73+
pollForUpdateInternal(future, delay);
74+
return future;
75+
}
76+
77+
private CompletableFuture<Void> pollForUpdateInternal(CompletableFuture<Operation> future, Duration delay) {
78+
return checkpointApiRequestBatcher.submit(null, delay).thenCompose(v -> {
7479
if (future.isDone()) {
7580
return CompletableFuture.completedFuture(null);
7681
}
77-
return checkpointApiRequestBatcher.submit(null, delay);
82+
return pollForUpdateInternal(future, delay);
7883
});
79-
return future;
8084
}
8185

8286
/** Cancels all polling futures and waits for all pending checkpoint requests to complete */
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable.execution;
4+
5+
import static org.junit.jupiter.api.Assertions.*;
6+
import static org.mockito.ArgumentMatchers.*;
7+
import static org.mockito.Mockito.*;
8+
9+
import java.time.Duration;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.concurrent.TimeUnit;
13+
import java.util.concurrent.TimeoutException;
14+
import org.junit.jupiter.api.BeforeEach;
15+
import org.junit.jupiter.api.Test;
16+
import software.amazon.awssdk.services.lambda.model.CheckpointDurableExecutionResponse;
17+
import software.amazon.awssdk.services.lambda.model.CheckpointUpdatedExecutionState;
18+
import software.amazon.awssdk.services.lambda.model.GetDurableExecutionStateResponse;
19+
import software.amazon.awssdk.services.lambda.model.Operation;
20+
import software.amazon.awssdk.services.lambda.model.OperationAction;
21+
import software.amazon.awssdk.services.lambda.model.OperationStatus;
22+
import software.amazon.awssdk.services.lambda.model.OperationType;
23+
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
24+
import software.amazon.lambda.durable.DurableConfig;
25+
import software.amazon.lambda.durable.client.DurableExecutionClient;
26+
27+
class CheckpointBatcherTest {
28+
29+
private DurableConfig config;
30+
private DurableExecutionClient client;
31+
private CheckpointBatcher batcher;
32+
private List<Operation> callbackOperations;
33+
34+
@BeforeEach
35+
void setUp() {
36+
client = mock(DurableExecutionClient.class);
37+
config = DurableConfig.builder()
38+
.withDurableExecutionClient(client)
39+
.withCheckpointDelay(Duration.ofMillis(50))
40+
.withPollingInterval(Duration.ofMillis(50))
41+
.build();
42+
43+
callbackOperations = new ArrayList<>();
44+
batcher = new CheckpointBatcher(config, "arn:test", "token-1", callbackOperations::addAll);
45+
}
46+
47+
@Test
48+
void checkpoint_sendsUpdateAndReturnsCompletedFuture() throws Exception {
49+
var update = OperationUpdate.builder()
50+
.id("op-1")
51+
.type(OperationType.STEP)
52+
.action(OperationAction.START)
53+
.build();
54+
55+
when(client.checkpoint(anyString(), anyString(), anyList()))
56+
.thenReturn(CheckpointDurableExecutionResponse.builder()
57+
.checkpointToken("token-2")
58+
.build());
59+
60+
var future = batcher.checkpoint(update);
61+
62+
// Wait for batch to flush
63+
future.get(200, TimeUnit.MILLISECONDS);
64+
65+
verify(client).checkpoint(eq("arn:test"), eq("token-1"), anyList());
66+
assertTrue(future.isDone());
67+
}
68+
69+
@Test
70+
void pollForUpdate_completesWhenOperationReturned() throws Exception {
71+
var operation = Operation.builder()
72+
.id("op-1")
73+
.type(OperationType.STEP)
74+
.status(OperationStatus.SUCCEEDED)
75+
.build();
76+
77+
when(client.checkpoint(anyString(), anyString(), anyList()))
78+
.thenReturn(CheckpointDurableExecutionResponse.builder()
79+
.checkpointToken("token-2")
80+
.newExecutionState(CheckpointUpdatedExecutionState.builder()
81+
.operations(List.of(operation))
82+
.build())
83+
.build());
84+
85+
var future = batcher.pollForUpdate("op-1");
86+
87+
assertFalse(future.isDone());
88+
89+
// Wait for polling to trigger checkpoint
90+
var result = future.get(300, TimeUnit.MILLISECONDS);
91+
92+
assertEquals(operation, result);
93+
assertEquals(1, callbackOperations.size());
94+
}
95+
96+
@Test
97+
void pollForUpdate_doesNotCompleteWhenDifferentOperationReturned() throws Exception {
98+
var operation = Operation.builder()
99+
.id("op-2")
100+
.type(OperationType.STEP)
101+
.status(OperationStatus.SUCCEEDED)
102+
.build();
103+
104+
when(client.checkpoint(anyString(), anyString(), anyList()))
105+
.thenReturn(CheckpointDurableExecutionResponse.builder()
106+
.checkpointToken("token-2")
107+
.newExecutionState(CheckpointUpdatedExecutionState.builder()
108+
.operations(List.of(operation))
109+
.build())
110+
.build());
111+
112+
var future = batcher.pollForUpdate("op-1");
113+
114+
// Should timeout since op-1 never returned
115+
assertThrows(TimeoutException.class, () -> future.get(200, TimeUnit.MILLISECONDS));
116+
}
117+
118+
@Test
119+
void pollForUpdate_handlesMultiplePollers() throws Exception {
120+
var operation = Operation.builder()
121+
.id("op-1")
122+
.type(OperationType.STEP)
123+
.status(OperationStatus.SUCCEEDED)
124+
.build();
125+
126+
when(client.checkpoint(anyString(), anyString(), anyList()))
127+
.thenReturn(CheckpointDurableExecutionResponse.builder()
128+
.checkpointToken("token-2")
129+
.newExecutionState(CheckpointUpdatedExecutionState.builder()
130+
.operations(List.of(operation))
131+
.build())
132+
.build());
133+
134+
var future1 = batcher.pollForUpdate("op-1");
135+
var future2 = batcher.pollForUpdate("op-1");
136+
var future3 = batcher.pollForUpdate("op-1");
137+
138+
var result1 = future1.get(300, TimeUnit.MILLISECONDS);
139+
var result2 = future2.get(300, TimeUnit.MILLISECONDS);
140+
var result3 = future3.get(300, TimeUnit.MILLISECONDS);
141+
142+
assertEquals(operation, result1);
143+
assertEquals(operation, result2);
144+
assertEquals(operation, result3);
145+
}
146+
147+
@Test
148+
void shutdown_completesAllPendingPollersWithException() {
149+
var future1 = batcher.pollForUpdate("op-1");
150+
var future2 = batcher.pollForUpdate("op-2");
151+
152+
batcher.shutdown();
153+
154+
assertTrue(future1.isCompletedExceptionally());
155+
assertTrue(future2.isCompletedExceptionally());
156+
157+
assertThrows(Exception.class, future1::join);
158+
assertThrows(Exception.class, future2::join);
159+
}
160+
161+
@Test
162+
void shutdown_waitsForPendingCheckpoints() throws Exception {
163+
when(client.checkpoint(anyString(), anyString(), anyList()))
164+
.thenReturn(CheckpointDurableExecutionResponse.builder()
165+
.checkpointToken("token-2")
166+
.build());
167+
168+
var future = batcher.checkpoint(OperationUpdate.builder()
169+
.id("op-1")
170+
.action(OperationAction.START)
171+
.type(OperationType.STEP)
172+
.build());
173+
174+
batcher.shutdown();
175+
176+
assertTrue(future.isDone());
177+
verify(client, atLeastOnce()).checkpoint(anyString(), anyString(), anyList());
178+
}
179+
180+
@Test
181+
void fetchAllPages_retrievesAllOperations() {
182+
var op1 = Operation.builder().id("op-1").build();
183+
var op2 = Operation.builder().id("op-2").build();
184+
var op3 = Operation.builder().id("op-3").build();
185+
186+
when(client.getExecutionState(eq("arn:test"), eq("token-1"), eq("marker-1")))
187+
.thenReturn(GetDurableExecutionStateResponse.builder()
188+
.operations(List.of(op2))
189+
.nextMarker("marker-2")
190+
.build());
191+
192+
when(client.getExecutionState(eq("arn:test"), eq("token-1"), eq("marker-2")))
193+
.thenReturn(GetDurableExecutionStateResponse.builder()
194+
.operations(List.of(op3))
195+
.nextMarker(null)
196+
.build());
197+
198+
var state = CheckpointUpdatedExecutionState.builder()
199+
.operations(List.of(op1))
200+
.nextMarker("marker-1")
201+
.build();
202+
203+
var result = batcher.fetchAllPages(state);
204+
205+
assertEquals(3, result.size());
206+
assertEquals("op-1", result.get(0).id());
207+
assertEquals("op-2", result.get(1).id());
208+
assertEquals("op-3", result.get(2).id());
209+
}
210+
211+
@Test
212+
void fetchAllPages_handlesNullState() {
213+
var result = batcher.fetchAllPages(null);
214+
215+
assertEquals(0, result.size());
216+
verify(client, never()).getExecutionState(anyString(), anyString(), anyString());
217+
}
218+
219+
@Test
220+
void fetchAllPages_handlesEmptyMarker() {
221+
var state = CheckpointUpdatedExecutionState.builder()
222+
.operations(List.of(Operation.builder().id("op-1").build()))
223+
.nextMarker("")
224+
.build();
225+
226+
var result = batcher.fetchAllPages(state);
227+
228+
assertEquals(1, result.size());
229+
verify(client, never()).getExecutionState(anyString(), anyString(), anyString());
230+
}
231+
232+
@Test
233+
void checkpoint_updatesCheckpointToken() throws Exception {
234+
when(client.checkpoint(anyString(), eq("token-1"), anyList()))
235+
.thenReturn(CheckpointDurableExecutionResponse.builder()
236+
.checkpointToken("token-2")
237+
.build());
238+
239+
when(client.checkpoint(anyString(), eq("token-2"), anyList()))
240+
.thenReturn(CheckpointDurableExecutionResponse.builder()
241+
.checkpointToken("token-3")
242+
.build());
243+
244+
batcher.checkpoint(OperationUpdate.builder()
245+
.id("op-1")
246+
.type(OperationType.STEP)
247+
.action(OperationAction.SUCCEED)
248+
.build())
249+
.get(200, TimeUnit.MILLISECONDS);
250+
251+
batcher.checkpoint(OperationUpdate.builder()
252+
.id("op-2")
253+
.type(OperationType.STEP)
254+
.action(OperationAction.START)
255+
.build())
256+
.get(200, TimeUnit.MILLISECONDS);
257+
258+
verify(client).checkpoint(eq("arn:test"), eq("token-1"), anyList());
259+
verify(client).checkpoint(eq("arn:test"), eq("token-2"), anyList());
260+
}
261+
262+
@Test
263+
void pollForUpdate_withCustomDelay() throws Exception {
264+
var operation =
265+
Operation.builder().id("op-1").status(OperationStatus.SUCCEEDED).build();
266+
267+
when(client.checkpoint(anyString(), anyString(), anyList()))
268+
.thenReturn(CheckpointDurableExecutionResponse.builder()
269+
.checkpointToken("token-2")
270+
.newExecutionState(CheckpointUpdatedExecutionState.builder()
271+
.operations(List.of(operation))
272+
.build())
273+
.build());
274+
275+
var future = batcher.pollForUpdate("op-1", Duration.ofMillis(100));
276+
277+
var result = future.get(300, TimeUnit.MILLISECONDS);
278+
279+
assertEquals(operation, result);
280+
}
281+
282+
@Test
283+
void checkpoint_filtersNullUpdates() throws Exception {
284+
when(client.checkpoint(anyString(), anyString(), anyList()))
285+
.thenReturn(CheckpointDurableExecutionResponse.builder()
286+
.checkpointToken("token-2")
287+
.build());
288+
289+
// Submit null (from polling) and real update
290+
batcher.pollForUpdate("op-1");
291+
batcher.checkpoint(OperationUpdate.builder()
292+
.id("op-2")
293+
.type(OperationType.STEP)
294+
.action(OperationAction.START)
295+
.build())
296+
.get(200, TimeUnit.MILLISECONDS);
297+
298+
verify(client).checkpoint(eq("arn:test"), eq("token-1"), argThat(list -> {
299+
// Should only contain non-null update
300+
return list.stream().noneMatch(u -> u == null);
301+
}));
302+
}
303+
}

0 commit comments

Comments
 (0)