Skip to content

Commit c6214b3

Browse files
committed
prevent a race condition in future completion
1 parent 1284d48 commit c6214b3

3 files changed

Lines changed: 132 additions & 76 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,11 @@ public ThreadContext getCurrentThreadContext() {
191191
* @see ThreadContext
192192
*/
193193
public void registerActiveThread(String threadId) {
194-
if (activeThreads.contains(threadId)) {
194+
if (activeThreads.add(threadId)) {
195+
logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
196+
} else {
195197
logger.trace("Thread '{}' already registered as active", threadId);
196-
return;
197198
}
198-
activeThreads.add(threadId);
199-
logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
200199
}
201200

202201
/**

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.List;
77
import java.util.Objects;
88
import java.util.concurrent.CompletableFuture;
9+
import java.util.concurrent.atomic.AtomicBoolean;
910
import java.util.concurrent.atomic.AtomicReference;
1011
import org.slf4j.Logger;
1112
import org.slf4j.LoggerFactory;
@@ -179,23 +180,45 @@ protected Operation waitForOperationCompletion() {
179180
// It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
180181
// completionFuture is completed by a user thread (a step or child context thread) when the execution here
181182
// is between `isOperationCompleted` and `thenRun`.
182-
synchronized (completionFuture) {
183-
if (!isOperationCompleted()) {
184-
// Operation not done yet
185-
logger.trace(
186-
"deregistering thread {} when waiting for operation {} ({}) to complete ({})",
187-
threadContext.threadId(),
188-
getOperation(),
189-
getType(),
190-
completionFuture);
191-
192-
// Add a completion stage to completionFuture so that when the completionFuture is completed,
193-
// it will register the current Context thread synchronously to make sure it is always registered
194-
// strictly before the execution thread (Step or child context) is deregistered.
195-
completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId()));
196-
197-
// Deregister the current thread to allow suspension
198-
executionManager.deregisterActiveThread(threadContext.threadId());
183+
AtomicBoolean skipDeregistration = new AtomicBoolean(false);
184+
long callerThreadId = Thread.currentThread().getId();
185+
186+
// Add a completion stage to completionFuture so that when the completionFuture is completed,
187+
// it will register the current Context thread synchronously to make sure it is always registered
188+
// strictly before the execution thread (Step or child context) is deregistered.
189+
completionFuture.thenRun(() -> {
190+
// at this point, future.isDone = true
191+
if (Thread.currentThread().getId() == callerThreadId) {
192+
// If the completion thread is the same as the consumer thread (immediately completed),
193+
// we don't want to deregister and register the thread.
194+
skipDeregistration.set(true);
195+
return;
196+
}
197+
198+
synchronized (skipDeregistration) {
199+
if (!skipDeregistration.get()) {
200+
registerActiveThread(threadContext.threadId());
201+
}
202+
}
203+
});
204+
205+
// skip deregistration if future is already completed or skipDeregistration is true
206+
if (!isOperationCompleted() && !completionFuture.isDone() && !skipDeregistration.get()) {
207+
synchronized (skipDeregistration) {
208+
if (!skipDeregistration.get() && !completionFuture.isDone()) {
209+
// Operation not done yet
210+
logger.trace(
211+
"deregistering thread {} when waiting for operation {} ({}) to complete ({})",
212+
threadContext.threadId(),
213+
getOperation(),
214+
getType(),
215+
completionFuture);
216+
217+
// Deregister the current thread to allow suspension
218+
executionManager.deregisterActiveThread(threadContext.threadId());
219+
} else {
220+
skipDeregistration.set(true);
221+
}
199222
}
200223
}
201224

@@ -282,13 +305,11 @@ protected void markAlreadyCompleted() {
282305
private void markCompletionFutureCompleted() {
283306
// It's important that we synchronize access to the future, otherwise the processing could happen
284307
// on someone else's thread and cause a race condition.
285-
synchronized (completionFuture) {
286-
// Completing the future here will also run any other completion stages that have been attached
287-
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
288-
// so they will definitely have a chance to reactivate before we finish completing and deactivating
289-
// whatever operations were just checkpointed.
290-
completionFuture.complete(this);
291-
}
308+
// Completing the future here will also run any other completion stages that have been attached
309+
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
310+
// so they will definitely have a chance to reactivate before we finish completing and deactivating
311+
// whatever operations were just checkpointed.
312+
completionFuture.complete(this);
292313
}
293314

294315
/**

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected ConcurrencyOperation(
8282
this.toleratedFailureCount = toleratedFailureCount;
8383
this.operationIdGenerator = new OperationIdGenerator(getOperationId());
8484
this.rootContext = durableContext.createChildContext(getOperationId(), getName());
85-
this.consumerThreadListener = new AtomicReference<>(null);
85+
this.consumerThreadListener = new AtomicReference<>(new CompletableFuture<>());
8686
}
8787

8888
// ========== Template methods for subclasses ==========
@@ -138,15 +138,13 @@ protected <R> ChildContextOperation<R> enqueueItem(
138138
pendingQueue.add(childOp);
139139
logger.debug("Item enqueued {}", name);
140140
// notify the consumer thread a new item is available
141-
completeVacancyListenerIfSet();
141+
notifyConsumerThread();
142142
return childOp;
143143
}
144144

145-
private void completeVacancyListenerIfSet() {
146-
synchronized (this) {
147-
if (consumerThreadListener.get() != null) {
148-
consumerThreadListener.get().complete(null);
149-
}
145+
private void notifyConsumerThread() {
146+
synchronized (consumerThreadListener) {
147+
consumerThreadListener.get().complete(null);
150148
}
151149
}
152150

@@ -159,6 +157,17 @@ protected void executeItems() {
159157

160158
Runnable consumer = () -> {
161159
while (true) {
160+
// Set a new future if it's completed so that it will be able to receive a notification of
161+
// new items when the thread is checking completion condition and processing
162+
// the queued items below.
163+
synchronized (consumerThreadListener) {
164+
if (consumerThreadListener.get() != null
165+
&& consumerThreadListener.get().isDone()) {
166+
consumerThreadListener.set(new CompletableFuture<>());
167+
}
168+
}
169+
170+
// Process completion condition. Quit the loop if the condition is met.
162171
if (isOperationCompleted()) {
163172
return;
164173
}
@@ -167,14 +176,21 @@ protected void executeItems() {
167176
handleComplete(completionStatus);
168177
return;
169178
}
179+
180+
// process new items in the queue
170181
while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
171182
var next = pendingQueue.poll();
172183
runningChildren.add(next);
173184
logger.debug("Executing operation {}", next.getName());
174185
next.execute();
175186
}
176-
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
177-
// child may be null if the consumer thread is woken up due to a new item being added
187+
188+
// If consumerThreadListener has been completed when processing above, waitForChildCompletion will
189+
// immediately return null and repeat the above again
190+
var child = waitForChildCompletion(runningChildren);
191+
192+
// child may be null if the consumer thread is woken up due to new items added or completion condition
193+
// changed
178194
if (child != null) {
179195
if (runningChildren.contains(child)) {
180196
runningChildren.remove(child);
@@ -183,49 +199,67 @@ protected void executeItems() {
183199
throw new IllegalStateException("Unexpected completion: " + child);
184200
}
185201
}
186-
synchronized (this) {
187-
if (consumerThreadListener.get() != null
188-
&& consumerThreadListener.get().isDone()) {
189-
consumerThreadListener.set(null);
190-
}
191-
}
192202
}
193203
};
194204
// run consumer in the user thread pool, although it's not a real user thread
195205
runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT);
196206
}
197207

198-
private BaseDurableOperation waitForChildCompletion(
199-
AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
208+
private BaseDurableOperation waitForChildCompletion(Set<BaseDurableOperation> runningChildren) {
200209
var threadContext = getCurrentThreadContext();
201-
CompletableFuture<Object> future;
210+
ArrayList<CompletableFuture<BaseDurableOperation>> futures = new ArrayList<>(runningChildren.stream()
211+
.map(BaseDurableOperation::getCompletionFuture)
212+
.toList());
202213

203-
synchronized (this) {
204-
// check again in synchronized block to prevent race conditions
205-
if (isOperationCompleted()) {
206-
return null;
207-
}
208-
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
209-
if (completionStatus != null) {
210-
return null;
214+
// always add the future to listen to the new items or condition changes. This might have been
215+
// completed during the period of consuming items from the queue.
216+
futures.add(consumerThreadListener.get());
217+
218+
// future will be completed immediately if any future of the list is already completed
219+
CompletableFuture<Object> future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
220+
221+
// skip deregistering the current thread if it is already completed
222+
AtomicBoolean skipDeregistration = new AtomicBoolean(false);
223+
long callerThreadId = Thread.currentThread().getId();
224+
225+
// attach the stage for registering thread without checking future status to avoid race condition
226+
future.thenRun(() -> {
227+
// at this point, future.isDone = true
228+
if (Thread.currentThread().getId() == callerThreadId) {
229+
// If the completion thread is the same as the consumer thread (immediately completed),
230+
// we don't want to deregister and register the thread.
231+
skipDeregistration.set(true);
232+
return;
211233
}
212-
ArrayList<CompletableFuture<BaseDurableOperation>> futures;
213-
futures = new ArrayList<>(runningChildren.stream()
214-
.map(BaseDurableOperation::getCompletionFuture)
215-
.toList());
216-
if (futures.size() < maxConcurrency) {
217-
// add a future to listen to the new items if there is a vacancy
218-
consumerThreadListener.compareAndSet(null, new CompletableFuture<>());
219-
futures.add(consumerThreadListener.get());
234+
235+
synchronized (skipDeregistration) {
236+
if (!skipDeregistration.get()) {
237+
registerActiveThread(threadContext.threadId());
238+
}
220239
}
240+
});
241+
242+
// skip deregistration if future is already completed or skipDeregistration is true
243+
if (!future.isDone() && !skipDeregistration.get()) {
244+
// We try our best to skip deregistration/registration using the thread id check above. However,
245+
// there is still a small window where the completionFuture is completed right after the check
246+
// and before the deregisterActiveThread call.
221247

222-
// future will be completed immediately if any future of the list is already completed
223-
future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
224-
// skip deregistering the current thread if there is more completed future to process
225-
if (!future.isDone()) {
226-
future.thenRun(() -> registerActiveThread(threadContext.threadId()));
227-
// Deregister the current thread to allow suspension
228-
executionManager.deregisterActiveThread(threadContext.threadId());
248+
// Registering and deregistering threads might be competing for this lock.
249+
// If this deregistration thread skips deregistering, skipDeregistration will be set to true
250+
// so that the registration thread will skip registration
251+
synchronized (skipDeregistration) {
252+
if (!skipDeregistration.get() && !future.isDone()) {
253+
// Operation not done yet
254+
logger.trace(
255+
"deregistering consumer thread {} when waiting for {} branches to complete",
256+
threadContext.threadId(),
257+
getType());
258+
// Deregister the current thread to allow suspension
259+
executionManager.deregisterActiveThread(threadContext.threadId());
260+
} else {
261+
skipDeregistration.set(true);
262+
}
229263
}
230264
}
231265
return future.thenApply(o -> (BaseDurableOperation) o).join();
@@ -273,6 +307,8 @@ private ConcurrencyCompletionStatus canComplete(
273307
}
274308

275309
// All items finished — complete
310+
// This condition relies on isJoined, so the consumer will wake up and check this again when
311+
// isJoined is set to true.
276312
if (isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty()) {
277313
return ConcurrencyCompletionStatus.ALL_COMPLETED;
278314
}
@@ -281,12 +317,10 @@ private ConcurrencyCompletionStatus canComplete(
281317
}
282318

283319
private void handleComplete(ConcurrencyCompletionStatus status) {
284-
synchronized (this) {
285-
if (isOperationCompleted()) {
286-
return;
287-
}
288-
handleSuccess(status);
320+
if (isOperationCompleted()) {
321+
return;
289322
}
323+
handleSuccess(status);
290324
}
291325

292326
/**
@@ -299,8 +333,10 @@ protected void join() {
299333
+ ") exceeds the number of registered items (" + branches.size() + ")");
300334
}
301335
isJoined.set(true);
302-
// notify the execution thread this concurrency operation is joined
303-
completeVacancyListenerIfSet();
336+
337+
// Notify the consumer thread this concurrency operation is joined. The consumer thread needs to check the
338+
// completion condition again.
339+
notifyConsumerThread();
304340
waitForOperationCompletion();
305341
}
306342

0 commit comments

Comments
 (0)