Skip to content

Commit 6b2e0f2

Browse files
committed
prevent a race condition in future completion
1 parent c3dc73d commit 6b2e0f2

2 files changed

Lines changed: 89 additions & 71 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.List;
77
import java.util.Objects;
88
import java.util.concurrent.CompletableFuture;
9+
import java.util.concurrent.atomic.AtomicBoolean;
910
import java.util.concurrent.atomic.AtomicReference;
1011
import org.slf4j.Logger;
1112
import org.slf4j.LoggerFactory;
@@ -179,21 +180,31 @@ protected Operation waitForOperationCompletion() {
179180
// It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
180181
// completionFuture is completed by a user thread (a step or child context thread) when the execution here
181182
// is between `isOperationCompleted` and `thenRun`.
182-
synchronized (completionFuture) {
183-
if (!isOperationCompleted()) {
184-
// Operation not done yet
185-
logger.trace(
186-
"deregistering thread {} when waiting for operation {} ({}) to complete ({})",
187-
threadContext.threadId(),
188-
getOperation(),
189-
getType(),
190-
completionFuture);
191-
192-
// Add a completion stage to completionFuture so that when the completionFuture is completed,
193-
// it will register the current Context thread synchronously to make sure it is always registered
194-
// strictly before the execution thread (Step or child context) is deregistered.
195-
completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId()));
183+
if (!isOperationCompleted()) {
184+
// Operation not done yet
185+
logger.trace(
186+
"deregistering thread {} when waiting for operation {} ({}) to complete ({})",
187+
threadContext.threadId(),
188+
getOperation(),
189+
getType(),
190+
completionFuture);
191+
192+
// Add a completion stage to completionFuture so that when the completionFuture is completed,
193+
// it will register the current Context thread synchronously to make sure it is always registered
194+
// strictly before the execution thread (Step or child context) is deregistered.
195+
AtomicBoolean alreadyCompleted = new AtomicBoolean(false);
196+
long callerThreadId = Thread.currentThread().getId();
197+
completionFuture.thenRun(() -> {
198+
if (Thread.currentThread().getId() == callerThreadId) {
199+
// If the calling thread is same as the thread that initiated the wait, meaning the future is
200+
// already completed, we should skip deregister/register the thread.
201+
alreadyCompleted.set(true);
202+
return;
203+
}
204+
registerActiveThread(threadContext.threadId());
205+
});
196206

207+
if (!alreadyCompleted.get()) {
197208
// Deregister the current thread to allow suspension
198209
executionManager.deregisterActiveThread(threadContext.threadId());
199210
}
@@ -266,8 +277,11 @@ public void onCheckpointComplete(Operation operation) {
266277
// This method handles only terminal status updates. Override this method if a DurableOperation needs to
267278
// handle other updates.
268279
logger.trace("In onCheckpointComplete, completing operation {} ({})", getOperationId(), completionFuture);
269-
270-
markCompletionFutureCompleted();
280+
// Completing the future here will also run any other completion stages that have been attached
281+
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
282+
// so they will definitely have a chance to reactivate before we finish completing and deactivating
283+
// whatever operations were just checkpointed.
284+
completionFuture.complete(this);
271285
}
272286
}
273287

@@ -276,19 +290,11 @@ protected void markAlreadyCompleted() {
276290
// When the operation is already completed in a replay, we complete completionFuture immediately
277291
// so that the `get` method will be unblocked and the context thread will be registered
278292
logger.trace("In markAlreadyCompleted, completing operation: {} ({}).", getOperationId(), completionFuture);
279-
markCompletionFutureCompleted();
280-
}
281-
282-
private void markCompletionFutureCompleted() {
283-
// It's important that we synchronize access to the future, otherwise the processing could happen
284-
// on someone else's thread and cause a race condition.
285-
synchronized (completionFuture) {
286-
// Completing the future here will also run any other completion stages that have been attached
287-
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
288-
// so they will definitely have a chance to reactivate before we finish completing and deactivating
289-
// whatever operations were just checkpointed.
290-
completionFuture.complete(this);
291-
}
293+
// Completing the future here will also run any other completion stages that have been attached
294+
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
295+
// so they will definitely have a chance to reactivate before we finish completing and deactivating
296+
// whatever operations were just checkpointed.
297+
completionFuture.complete(this);
292298
}
293299

294300
/**

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected ConcurrencyOperation(
8282
this.toleratedFailureCount = toleratedFailureCount;
8383
this.operationIdGenerator = new OperationIdGenerator(getOperationId());
8484
this.rootContext = durableContext.createChildContext(getOperationId(), getName());
85-
this.consumerThreadListener = new AtomicReference<>(null);
85+
this.consumerThreadListener = new AtomicReference<>(new CompletableFuture<>());
8686
}
8787

8888
// ========== Template methods for subclasses ==========
@@ -138,15 +138,13 @@ protected <R> ChildContextOperation<R> enqueueItem(
138138
pendingQueue.add(childOp);
139139
logger.debug("Item enqueued {}", name);
140140
// notify the consumer thread a new item is available
141-
completeVacancyListenerIfSet();
141+
notifyConsumerThread();
142142
return childOp;
143143
}
144144

145-
private void completeVacancyListenerIfSet() {
145+
private void notifyConsumerThread() {
146146
synchronized (this) {
147-
if (consumerThreadListener.get() != null) {
148-
consumerThreadListener.get().complete(null);
149-
}
147+
consumerThreadListener.get().complete(null);
150148
}
151149
}
152150

@@ -159,6 +157,17 @@ protected void executeItems() {
159157

160158
Runnable consumer = () -> {
161159
while (true) {
160+
// Set a new future if it's completed so that it will be able to receive a notification of
161+
// new items when the thread is checking completion condition and processing
162+
// the queued items below.
163+
synchronized (this) {
164+
if (consumerThreadListener.get() != null
165+
&& consumerThreadListener.get().isDone()) {
166+
consumerThreadListener.set(new CompletableFuture<>());
167+
}
168+
}
169+
170+
// Process completion condition. Quit the loop if the condition is met.
162171
if (isOperationCompleted()) {
163172
return;
164173
}
@@ -167,14 +176,21 @@ protected void executeItems() {
167176
handleComplete(completionStatus);
168177
return;
169178
}
179+
180+
// process new items in the queue
170181
while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
171182
var next = pendingQueue.poll();
172183
runningChildren.add(next);
173184
logger.debug("Executing operation {}", next.getName());
174185
next.execute();
175186
}
176-
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
177-
// child may be null if the consumer thread is woken up due to a new item being added
187+
188+
// If consumerThreadListener has been completed when processing above, waitForChildCompletion will
189+
// immediately return null and repeat the above again
190+
var child = waitForChildCompletion(runningChildren);
191+
192+
// child may be null if the consumer thread is woken up due to new items added or completion condition
193+
// changed
178194
if (child != null) {
179195
if (runningChildren.contains(child)) {
180196
runningChildren.remove(child);
@@ -183,47 +199,39 @@ protected void executeItems() {
183199
throw new IllegalStateException("Unexpected completion: " + child);
184200
}
185201
}
186-
synchronized (this) {
187-
if (consumerThreadListener.get() != null
188-
&& consumerThreadListener.get().isDone()) {
189-
consumerThreadListener.set(null);
190-
}
191-
}
192202
}
193203
};
194204
// run consumer in the user thread pool, although it's not a real user thread
195205
runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT);
196206
}
197207

198-
private BaseDurableOperation waitForChildCompletion(
199-
AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
208+
private BaseDurableOperation waitForChildCompletion(Set<BaseDurableOperation> runningChildren) {
200209
var threadContext = getCurrentThreadContext();
201-
CompletableFuture<Object> future;
210+
ArrayList<CompletableFuture<BaseDurableOperation>> futures = new ArrayList<>(runningChildren.stream()
211+
.map(BaseDurableOperation::getCompletionFuture)
212+
.toList());
202213

203-
synchronized (this) {
204-
// check again in synchronized block to prevent race conditions
205-
if (isOperationCompleted()) {
206-
return null;
207-
}
208-
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
209-
if (completionStatus != null) {
210-
return null;
211-
}
212-
ArrayList<CompletableFuture<BaseDurableOperation>> futures;
213-
futures = new ArrayList<>(runningChildren.stream()
214-
.map(BaseDurableOperation::getCompletionFuture)
215-
.toList());
216-
if (futures.size() < maxConcurrency) {
217-
// add a future to listen to the new items if there is a vacancy
218-
consumerThreadListener.compareAndSet(null, new CompletableFuture<>());
219-
futures.add(consumerThreadListener.get());
220-
}
214+
// always add the future to listen to the new items or condition changes. This might have been
215+
// completed during the period of consuming items from the queue.
216+
futures.add(consumerThreadListener.get());
217+
218+
// future will be completed immediately if any future of the list is already completed
219+
CompletableFuture<Object> future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
221220

222-
// future will be completed immediately if any future of the list is already completed
223-
future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
224-
// skip deregistering the current thread if there is more completed future to process
225-
if (!future.isDone()) {
226-
future.thenRun(() -> registerActiveThread(threadContext.threadId()));
221+
// skip deregistering the current thread if there is more completed future to process
222+
AtomicBoolean futureCompletedImmediately = new AtomicBoolean(false);
223+
long callerThreadId = Thread.currentThread().getId();
224+
if (!future.isDone()) {
225+
future.thenRun(() -> {
226+
if (Thread.currentThread().getId() == callerThreadId) {
227+
// If the completion thread is the same as the consumer thread (immediately completed),
228+
// we don't want to deregister and register the thread.
229+
futureCompletedImmediately.set(true);
230+
return;
231+
}
232+
registerActiveThread(threadContext.threadId());
233+
});
234+
if (!futureCompletedImmediately.get()) {
227235
// Deregister the current thread to allow suspension
228236
executionManager.deregisterActiveThread(threadContext.threadId());
229237
}
@@ -273,6 +281,8 @@ private ConcurrencyCompletionStatus canComplete(
273281
}
274282

275283
// All items finished — complete
284+
// This condition relies on isJoined, so the consumer will wake up and check this again when
285+
// isJoined is set to true.
276286
if (isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty()) {
277287
return ConcurrencyCompletionStatus.ALL_COMPLETED;
278288
}
@@ -299,8 +309,10 @@ protected void join() {
299309
+ ") exceeds the number of registered items (" + branches.size() + ")");
300310
}
301311
isJoined.set(true);
302-
// notify the execution thread this concurrency operation is joined
303-
completeVacancyListenerIfSet();
312+
313+
// Notify the consumer thread this concurrency operation is joined. Consumer thread need to check the
314+
// completion condition again.
315+
notifyConsumerThread();
304316
waitForOperationCompletion();
305317
}
306318

0 commit comments

Comments
 (0)