Skip to content

Commit e4f73c3

Browse files
committed
prevent a race condition in future completion
1 parent 1284d48 commit e4f73c3

1 file changed

Lines changed: 75 additions & 42 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 75 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected ConcurrencyOperation(
8282
this.toleratedFailureCount = toleratedFailureCount;
8383
this.operationIdGenerator = new OperationIdGenerator(getOperationId());
8484
this.rootContext = durableContext.createChildContext(getOperationId(), getName());
85-
this.consumerThreadListener = new AtomicReference<>(null);
85+
this.consumerThreadListener = new AtomicReference<>(new CompletableFuture<>());
8686
}
8787

8888
// ========== Template methods for subclasses ==========
@@ -138,15 +138,13 @@ protected <R> ChildContextOperation<R> enqueueItem(
138138
pendingQueue.add(childOp);
139139
logger.debug("Item enqueued {}", name);
140140
// notify the consumer thread a new item is available
141-
completeVacancyListenerIfSet();
141+
notifyConsumerThread();
142142
return childOp;
143143
}
144144

145-
private void completeVacancyListenerIfSet() {
145+
private void notifyConsumerThread() {
146146
synchronized (this) {
147-
if (consumerThreadListener.get() != null) {
148-
consumerThreadListener.get().complete(null);
149-
}
147+
consumerThreadListener.get().complete(null);
150148
}
151149
}
152150

@@ -159,6 +157,17 @@ protected void executeItems() {
159157

160158
Runnable consumer = () -> {
161159
while (true) {
160+
// Set a new future if it's completed so that it will be able to receive a notification of
161+
// new items when the thread is checking completion condition and processing
162+
// the queued items below.
163+
synchronized (this) {
164+
if (consumerThreadListener.get() != null
165+
&& consumerThreadListener.get().isDone()) {
166+
consumerThreadListener.set(new CompletableFuture<>());
167+
}
168+
}
169+
170+
// Process completion condition. Quit the loop if the condition is met.
162171
if (isOperationCompleted()) {
163172
return;
164173
}
@@ -167,14 +176,21 @@ protected void executeItems() {
167176
handleComplete(completionStatus);
168177
return;
169178
}
179+
180+
// process new items in the queue
170181
while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
171182
var next = pendingQueue.poll();
172183
runningChildren.add(next);
173184
logger.debug("Executing operation {}", next.getName());
174185
next.execute();
175186
}
176-
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
177-
// child may be null if the consumer thread is woken up due to a new item being added
187+
188+
// If consumerThreadListener has been completed when processing above, waitForChildCompletion will
189+
// immediately return null and repeat the above again
190+
var child = waitForChildCompletion(runningChildren);
191+
192+
// child may be null if the consumer thread is woken up due to new items added or completion condition
193+
// changed
178194
if (child != null) {
179195
if (runningChildren.contains(child)) {
180196
runningChildren.remove(child);
@@ -183,49 +199,62 @@ protected void executeItems() {
183199
throw new IllegalStateException("Unexpected completion: " + child);
184200
}
185201
}
186-
synchronized (this) {
187-
if (consumerThreadListener.get() != null
188-
&& consumerThreadListener.get().isDone()) {
189-
consumerThreadListener.set(null);
190-
}
191-
}
192202
}
193203
};
194204
// run consumer in the user thread pool, although it's not a real user thread
195205
runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT);
196206
}
197207

198-
private BaseDurableOperation waitForChildCompletion(
199-
AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
208+
private BaseDurableOperation waitForChildCompletion(Set<BaseDurableOperation> runningChildren) {
200209
var threadContext = getCurrentThreadContext();
201-
CompletableFuture<Object> future;
210+
ArrayList<CompletableFuture<BaseDurableOperation>> futures = new ArrayList<>(runningChildren.stream()
211+
.map(BaseDurableOperation::getCompletionFuture)
212+
.toList());
202213

203-
synchronized (this) {
204-
// check again in synchronized block to prevent race conditions
205-
if (isOperationCompleted()) {
206-
return null;
207-
}
208-
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
209-
if (completionStatus != null) {
210-
return null;
214+
// always add the future to listen to the new items or condition changes. This might have been
215+
// completed during the period of consuming items from the queue.
216+
futures.add(consumerThreadListener.get());
217+
218+
// future will be completed immediately if any future of the list is already completed
219+
CompletableFuture<Object> future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
220+
221+
// skip deregistering the current thread if it is already completed
222+
AtomicBoolean skipDeregistration = new AtomicBoolean(false);
223+
long callerThreadId = Thread.currentThread().getId();
224+
225+
// attach the stage for registering thread without checking future status to avoid race condition
226+
future.thenRun(() -> {
227+
// at this point, future.isDone = true
228+
if (Thread.currentThread().getId() == callerThreadId) {
229+
// If the completion thread is the same as the consumer thread (immediately completed),
230+
// we don't want to deregister and register the thread.
231+
skipDeregistration.set(true);
232+
return;
211233
}
212-
ArrayList<CompletableFuture<BaseDurableOperation>> futures;
213-
futures = new ArrayList<>(runningChildren.stream()
214-
.map(BaseDurableOperation::getCompletionFuture)
215-
.toList());
216-
if (futures.size() < maxConcurrency) {
217-
// add a future to listen to the new items if there is a vacancy
218-
consumerThreadListener.compareAndSet(null, new CompletableFuture<>());
219-
futures.add(consumerThreadListener.get());
234+
235+
synchronized (this) {
236+
if (!skipDeregistration.get()) {
237+
registerActiveThread(threadContext.threadId());
238+
}
220239
}
240+
});
221241

222-
// future will be completed immediately if any future of the list is already completed
223-
future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
224-
// skip deregistering the current thread if there is more completed future to process
225-
if (!future.isDone()) {
226-
future.thenRun(() -> registerActiveThread(threadContext.threadId()));
227-
// Deregister the current thread to allow suspension
228-
executionManager.deregisterActiveThread(threadContext.threadId());
242+
// skip deregistration if future is always completed or skipDeregistration is true
243+
if (!future.isDone() || !skipDeregistration.get()) {
244+
// We tried best to skip deregistration/registration using thread id check above. However,
245+
// there is still a small window where the completionFuture is completed right after the check
246+
// and before the deregisterActiveThread call.
247+
248+
// Registering and deregistering threads might be competing for this lock.
249+
// If this deregistration thread skips deregistering, skipDeregistration will be set to true
250+
// so that the registration thread will skip registration
251+
synchronized (this) {
252+
if (!skipDeregistration.get() && !future.isDone()) {
253+
// Deregister the current thread to allow suspension
254+
executionManager.deregisterActiveThread(threadContext.threadId());
255+
} else {
256+
skipDeregistration.set(true);
257+
}
229258
}
230259
}
231260
return future.thenApply(o -> (BaseDurableOperation) o).join();
@@ -273,6 +302,8 @@ private ConcurrencyCompletionStatus canComplete(
273302
}
274303

275304
// All items finished — complete
305+
// This condition relies on isJoined, so the consumer will wake up and check this again when
306+
// isJoined is set to true.
276307
if (isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty()) {
277308
return ConcurrencyCompletionStatus.ALL_COMPLETED;
278309
}
@@ -299,8 +330,10 @@ protected void join() {
299330
+ ") exceeds the number of registered items (" + branches.size() + ")");
300331
}
301332
isJoined.set(true);
302-
// notify the execution thread this concurrency operation is joined
303-
completeVacancyListenerIfSet();
333+
334+
// Notify the consumer thread this concurrency operation is joined. Consumer thread need to check the
335+
// completion condition again.
336+
notifyConsumerThread();
304337
waitForOperationCompletion();
305338
}
306339

0 commit comments

Comments
 (0)