Skip to content

Commit 34c79a8

Browse files
committed
add tests for batcher
1 parent d396bec commit 34c79a8

3 files changed

Lines changed: 226 additions & 24 deletions

File tree

sdk/src/main/java/com/amazonaws/lambda/durable/execution/AsyncBatcher.java renamed to sdk/src/main/java/com/amazonaws/lambda/durable/execution/ApiRequestBatcher.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,20 @@
1212
import java.util.stream.Collectors;
1313

1414
/**
15-
* This class simplifies automatic batching of async requests. Your code deals with individual items, the service you
16-
* are calling asynchronously has a cheaper batch API. You are willing to trade some latency by waiting for more calls
17-
* to arrive to group them in a single batch call into the service. The batch call will be made when either a full batch
18-
* is built, too much time has passed, or size limits are reached. This class builds a single batch at a time with
19-
* thread-safe synchronization: - There is no batch yet. - First call arrives. Create a batch with one item in it, start
20-
* a timer. No call to service is made yet. - More calls arrive. They get added to the same batch if size limits allow.
21-
* - Either the batch is full, the timer has elapsed, or size limits are reached. Send the batch request. Now a new
22-
* batch can now be built. - If entire batch call fails, each call will fail. - If batch call succeeded, outcome is
23-
* analyzed one by one to complete results of each call. When you extend this class, you are expected to implement the
24-
* actual batch operation and to expose a public method to perform a single action. The batcher includes comprehensive
25-
* metrics tracking for performance monitoring.
15+
* This class simplifies automatic batching of api requests. The individual request items will be grouped if the service
16+
* has a cheaper batch API, and we want to trade some latency by waiting for more calls to arrive. The batch call will
17+
* be made when either a full batch is built, too much time has passed, or size limits are reached. This class builds a
18+
* single batch at a time with thread-safe synchronization: - There is no batch yet. - First call arrives. Create a
19+
* batch with one item in it, start a timer. No call to service is made yet. - More calls arrive. They get added to the
20+
* same batch if size limits allow. - Either the batch is full, the timer has elapsed, or size limits are reached. Send
21+
* the batch request. Now a new batch can now be built. - If entire batch call fails, each call will fail. - If batch
22+
* call succeeded, outcome is analyzed one by one to complete results of each call. When you extend this class, you are
23+
* expected to implement the actual batch operation and to expose a public method to perform a single action. The
24+
* batcher includes comprehensive metrics tracking for performance monitoring.
2625
*
2726
* @param <T> Input of every call
2827
*/
29-
public class AsyncBatcher<T> {
28+
public class ApiRequestBatcher<T> {
3029

3130
/** Maximum time to wait before flushing a batch */
3231
private final Duration maxDelay;
@@ -115,14 +114,14 @@ private Void failAllItems(Throwable wrappedCause) {
115114
private CompletableFuture<Void> batchFlushFuture;
116115

117116
/**
118-
* Creates a new AsyncBatcher with the specified configuration.
117+
* Creates a new ApiRequestBatcher with the specified configuration.
119118
*
120119
* @param maxDelay Maximum time to wait before flushing a batch
121120
* @param maxBatchSize Maximum number of items per batch
122121
* @param maxBatchBinarySizeInBytes Maximum total size in bytes for all items in a batch
123122
* @param itemSizeInBytesProvider Function to calculate the size in bytes of each item
124123
*/
125-
public AsyncBatcher(
124+
public ApiRequestBatcher(
126125
Duration maxDelay,
127126
int maxBatchSize,
128127
int maxBatchBinarySizeInBytes,

sdk/src/main/java/com/amazonaws/lambda/durable/execution/CheckpointBatcher.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
*/
3232
class CheckpointBatcher {
3333
private static final int MAX_BATCH_SIZE_BYTES = 750 * 1024; // 750KB
34+
private static final int MAX_BATCH_SIZE = 100; // max updates in one batch
3435
private static final Logger logger = LoggerFactory.getLogger(CheckpointBatcher.class);
3536

3637
private final Consumer<List<Operation>> callback;
@@ -39,9 +40,8 @@ class CheckpointBatcher {
3940
private final AtomicBoolean isRunning = new AtomicBoolean(true);
4041
private final Duration pollingInterval;
4142
private final Map<String, List<CompletableFuture<Operation>>> pollingFutures = new ConcurrentHashMap<>();
42-
private final AsyncBatcher<OperationUpdate> checkpointAsyncBatcher;
43+
private final ApiRequestBatcher<OperationUpdate> checkpointApiRequestBatcher;
4344
private String checkpointToken;
44-
private final Object batchCheckpointLock = new Object();
4545

4646
CheckpointBatcher(
4747
DurableConfig config,
@@ -53,9 +53,9 @@ class CheckpointBatcher {
5353
this.callback = callback;
5454
this.checkpointToken = checkpointToken;
5555
this.pollingInterval = config.getPollingInterval();
56-
this.checkpointAsyncBatcher = new AsyncBatcher<>(
56+
this.checkpointApiRequestBatcher = new ApiRequestBatcher<>(
5757
config.getCheckpointDelay(),
58-
Integer.MAX_VALUE,
58+
MAX_BATCH_SIZE,
5959
MAX_BATCH_SIZE_BYTES,
6060
CheckpointBatcher::estimateSize,
6161
this::doBatchAction);
@@ -65,7 +65,7 @@ class CheckpointBatcher {
6565

6666
CompletableFuture<Void> checkpoint(OperationUpdate update) {
6767
logger.debug("Checkpoint request received: Action {}", update.action());
68-
return checkpointAsyncBatcher.doAction(update);
68+
return checkpointApiRequestBatcher.doAction(update);
6969
}
7070

7171
CompletableFuture<Operation> pollForUpdate(String operationId) {
@@ -76,6 +76,7 @@ CompletableFuture<Operation> pollForUpdate(String operationId) {
7676
pollingFutures
7777
.computeIfAbsent(operationId, k -> Collections.synchronizedList(new ArrayList<>()))
7878
.add(future);
79+
pollingFutures.notifyAll();
7980
}
8081
return future;
8182
}
@@ -111,17 +112,30 @@ private void processQueue() {
111112
// background thread running
112113
while (isRunning.get()) {
113114
if (!pollingFutures.isEmpty()) {
115+
// If pollers exist, poll for updates periodically.
116+
// When wake from wait, sleep for an initial delay before calling the API
117+
try {
118+
Thread.sleep(pollingInterval.toMillis());
119+
} catch (InterruptedException ignored) {
120+
// ignored
121+
}
114122
doBatchAction(List.of()).join();
115-
}
116-
try {
117-
Thread.sleep(pollingInterval.toMillis());
118-
} catch (InterruptedException ignored) {
123+
} else {
124+
// if empty, waiting for new pollers
125+
synchronized (pollingFutures) {
126+
try {
127+
pollingFutures.wait(pollingInterval.toMillis());
128+
} catch (InterruptedException e) {
129+
Thread.currentThread().interrupt();
130+
break;
131+
}
132+
}
119133
}
120134
}
121135
}
122136

123137
protected CompletableFuture<Void> doBatchAction(List<OperationUpdate> updates) {
124-
// doBatchAction will be called from the polling thread and also from AsyncBatcher.
138+
// doBatchAction will be called from the polling thread and also from ApiRequestBatcher.
125139
// Use synchronized here to make sure no concurrent checkpoint API calls
126140
synchronized (pollingFutures) {
127141
logger.debug("Calling durable API checkpointDurableExecution with {} updates", updates.size());
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable.execution;
4+
5+
import static org.junit.jupiter.api.Assertions.assertEquals;
6+
import static org.junit.jupiter.api.Assertions.assertFalse;
7+
import static org.junit.jupiter.api.Assertions.assertTrue;
8+
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.Mockito.mock;
10+
import static org.mockito.Mockito.never;
11+
import static org.mockito.Mockito.timeout;
12+
import static org.mockito.Mockito.verify;
13+
import static org.mockito.Mockito.when;
14+
15+
import java.time.Clock;
16+
import java.time.Duration;
17+
import java.time.Instant;
18+
import java.time.ZoneOffset;
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
import java.util.concurrent.CompletableFuture;
22+
import java.util.concurrent.ExecutionException;
23+
import java.util.concurrent.TimeUnit;
24+
import java.util.function.Function;
25+
import org.junit.jupiter.api.BeforeEach;
26+
import org.junit.jupiter.api.Test;
27+
28+
class ApiRequestBatcherTest {
29+
private static final Duration MAX_DELAY_MILLIS = Duration.ofMillis(100);
30+
private static final int MAX_BATCH_SIZE = 3;
31+
private static final int MAX_BATCH_BINARY_SIZE_IN_BYTES = 100;
32+
33+
private static class Input {}
34+
35+
private Input input;
36+
private Clock fixedClock;
37+
private ApiRequestBatcher<Input> cut;
38+
private Function<List<Input>, CompletableFuture<Void>> doBatchAction;
39+
private CompletableFuture<Void> batchResultFuture;
40+
41+
@BeforeEach
42+
void setUp() {
43+
input = mock(Input.class);
44+
doBatchAction = mock();
45+
batchResultFuture = new CompletableFuture<>();
46+
fixedClock = Clock.fixed(Instant.now(), ZoneOffset.UTC);
47+
cut = new ApiRequestBatcher<>(
48+
MAX_DELAY_MILLIS, MAX_BATCH_SIZE, MAX_BATCH_BINARY_SIZE_IN_BYTES, item -> 0, doBatchAction);
49+
50+
when(doBatchAction.apply(any())).thenReturn(batchResultFuture);
51+
}
52+
53+
@Test
54+
void whenSingleActionPerformed_anUncompletedFutureIsReturned() {
55+
CompletableFuture<Void> resultFuture = cut.doAction(input);
56+
57+
verify(doBatchAction, never()).apply(any());
58+
assertFalse(resultFuture.isDone());
59+
}
60+
61+
@Test
62+
void whenMultipleActionsPerformedBelowMaxBatchSize_anUncompletedFutureIsReturnedEachTime() {
63+
List<CompletableFuture<Void>> resultFutures = new ArrayList<>();
64+
for (int i = 0; i < MAX_BATCH_SIZE - 1; i++) {
65+
resultFutures.add(cut.doAction(input));
66+
}
67+
68+
verify(doBatchAction, never()).apply(any());
69+
assertTrue(resultFutures.stream().noneMatch(CompletableFuture::isDone));
70+
}
71+
72+
@Test
73+
void whenMultipleActionsPerformedMatchingMaxBatchSize_batchInvokeIsPerformed() {
74+
List<CompletableFuture<Void>> resultFutures = new ArrayList<>();
75+
for (int i = 0; i < MAX_BATCH_SIZE; i++) {
76+
resultFutures.add(cut.doAction(input));
77+
}
78+
79+
verify(doBatchAction).apply(any());
80+
assertTrue(resultFutures.stream().noneMatch(CompletableFuture::isDone));
81+
}
82+
83+
@Test
84+
void whenBatchInvokeThrows_allFuturesCompleteWithThatException() throws InterruptedException {
85+
CompletableFuture<Void> resultFuture1 = cut.doAction(input);
86+
CompletableFuture<Void> resultFuture2 = cut.doAction(input);
87+
CompletableFuture<Void> resultFuture3 = cut.doAction(input);
88+
89+
assertFalse(resultFuture1.isDone());
90+
assertFalse(resultFuture2.isDone());
91+
assertFalse(resultFuture3.isDone());
92+
93+
Throwable batchCause = mock(Throwable.class);
94+
batchResultFuture.completeExceptionally(batchCause);
95+
96+
assertTrue(resultFuture1.isCompletedExceptionally());
97+
assertTrue(resultFuture2.isCompletedExceptionally());
98+
assertTrue(resultFuture3.isCompletedExceptionally());
99+
100+
assertEquals(batchCause, getFutureCause(resultFuture1));
101+
assertEquals(batchCause, getFutureCause(resultFuture2));
102+
assertEquals(batchCause, getFutureCause(resultFuture3));
103+
}
104+
105+
@Test
106+
void whenBatchInvokeReturnsOutcome_allFuturesCompleteSuccessfully() {
107+
Input input1 = mock(Input.class);
108+
Input input2 = mock(Input.class);
109+
Input input3 = mock(Input.class);
110+
111+
CompletableFuture<Void> resultFuture1 = cut.doAction(input1);
112+
CompletableFuture<Void> resultFuture2 = cut.doAction(input2);
113+
CompletableFuture<Void> resultFuture3 = cut.doAction(input3);
114+
115+
assertFalse(resultFuture1.isDone());
116+
assertFalse(resultFuture2.isDone());
117+
assertFalse(resultFuture3.isDone());
118+
119+
batchResultFuture.complete(null);
120+
121+
assertTrue(resultFuture1.isDone());
122+
assertTrue(resultFuture2.isDone());
123+
assertTrue(resultFuture3.isDone());
124+
125+
assertFalse(resultFuture1.isCompletedExceptionally());
126+
assertFalse(resultFuture2.isCompletedExceptionally());
127+
assertFalse(resultFuture3.isCompletedExceptionally());
128+
}
129+
130+
@Test
131+
void testDoAction_whenCannotAddItemDueToBinarySizeConstraint_thenFlushCurrentBatchAndCreateNewOne() {
132+
var cut = new ApiRequestBatcher<>(
133+
MAX_DELAY_MILLIS,
134+
MAX_BATCH_SIZE,
135+
MAX_BATCH_BINARY_SIZE_IN_BYTES,
136+
item -> MAX_BATCH_BINARY_SIZE_IN_BYTES,
137+
doBatchAction);
138+
List<CompletableFuture<Void>> resultFutures = new ArrayList<>();
139+
140+
resultFutures.add(cut.doAction(input));
141+
resultFutures.add(cut.doAction(input));
142+
143+
verify(doBatchAction).apply(any());
144+
145+
assertTrue(resultFutures.stream().noneMatch(CompletableFuture::isDone));
146+
}
147+
148+
@Test
149+
void whenTimerFires_batchIsProcessed() {
150+
var timerCut = new ApiRequestBatcher<>(
151+
Duration.ofMillis(1), MAX_BATCH_SIZE, MAX_BATCH_BINARY_SIZE_IN_BYTES, item -> 0, doBatchAction);
152+
153+
CompletableFuture<Void> resultFuture = timerCut.doAction(input);
154+
155+
// Wait for the timeout to trigger
156+
CompletableFuture.delayedExecutor(10, TimeUnit.MILLISECONDS).execute(() -> {});
157+
158+
verify(doBatchAction, timeout(50)).apply(any());
159+
assertFalse(resultFuture.isDone());
160+
}
161+
162+
@Test
163+
void whenBatchInvokeThrowsCompletionException_allFuturesCompleteWithUnwrappedCause() throws InterruptedException {
164+
CompletableFuture<Void> resultFuture1 = cut.doAction(input);
165+
CompletableFuture<Void> resultFuture2 = cut.doAction(input);
166+
CompletableFuture<Void> resultFuture3 = cut.doAction(input);
167+
168+
RuntimeException rootCause = new RuntimeException("Root cause");
169+
batchResultFuture.completeExceptionally(rootCause);
170+
171+
assertTrue(resultFuture1.isCompletedExceptionally());
172+
assertTrue(resultFuture2.isCompletedExceptionally());
173+
assertTrue(resultFuture3.isCompletedExceptionally());
174+
175+
// Should get unwrapped root cause, not the CompletionException wrapper
176+
assertEquals(rootCause, getFutureCause(resultFuture1));
177+
assertEquals(rootCause, getFutureCause(resultFuture2));
178+
assertEquals(rootCause, getFutureCause(resultFuture3));
179+
}
180+
181+
private Throwable getFutureCause(CompletableFuture<?> failedFuture) throws InterruptedException {
182+
try {
183+
failedFuture.get();
184+
return null;
185+
} catch (ExecutionException cause) {
186+
return cause.getCause();
187+
}
188+
}
189+
}

0 commit comments

Comments
 (0)