Skip to content

Commit 263864e

Browse files
authored
prevent a race condition in concurrent operation (aws#250)
1 parent 0c8c2ad commit 263864e

6 files changed

Lines changed: 72 additions & 47 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,13 @@ public ThreadContext getCurrentThreadContext() {
191191
* @see ThreadContext
192192
*/
193193
public void registerActiveThread(String threadId) {
194-
if (activeThreads.contains(threadId)) {
195-
logger.trace("Thread '{}' already registered as active", threadId);
196-
return;
194+
synchronized (activeThreads) {
195+
if (activeThreads.add(threadId)) {
196+
logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
197+
} else {
198+
logger.warn("Thread '{}' already registered as active", threadId);
199+
}
197200
}
198-
activeThreads.add(threadId);
199-
logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
200201
}
201202

202203
/**
@@ -210,16 +211,20 @@ public void deregisterActiveThread(String threadId) {
210211
return;
211212
}
212213

213-
boolean removed = activeThreads.remove(threadId);
214-
if (removed) {
215-
logger.trace("Deregistered thread '{}' Active threads: {}", threadId, activeThreads.size());
216-
} else {
217-
logger.warn("Thread '{}' not active, cannot deregister", threadId);
218-
}
214+
// Synchronize to avoid a remove-then-check race condition and to make sure that
215+
// suspendExecution is called only once
216+
synchronized (activeThreads) {
217+
boolean removed = activeThreads.remove(threadId);
218+
if (removed) {
219+
logger.trace("Deregistered thread '{}' Active threads: {}", threadId, activeThreads.size());
220+
} else {
221+
logger.warn("Thread '{}' not active, cannot deregister", threadId);
222+
}
219223

220-
if (activeThreads.isEmpty()) {
221-
logger.info("No active threads remaining - suspending execution");
222-
suspendExecution();
224+
if (activeThreads.isEmpty()) {
225+
logger.info("No active threads remaining - suspending execution");
226+
suspendExecution();
227+
}
223228
}
224229
}
225230

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ public abstract class BaseDurableOperation {
4747
private final OperationIdentifier operationIdentifier;
4848
protected final ExecutionManager executionManager;
4949
protected final CompletableFuture<BaseDurableOperation> completionFuture;
50+
protected final BaseDurableOperation parentOperation;
5051
private final DurableContextImpl durableContext;
5152
private final AtomicReference<CompletableFuture<Void>> runningUserHandler = new AtomicReference<>(null);
5253

@@ -55,9 +56,14 @@ public abstract class BaseDurableOperation {
5556
*
5657
* @param operationIdentifier the unique identifier for this operation
5758
* @param durableContext the parent context this operation belongs to
59+
* @param parentOperation the parent operation if this is a branch/iteration of a ConcurrencyOperation, or null otherwise
5860
*/
59-
protected BaseDurableOperation(OperationIdentifier operationIdentifier, DurableContextImpl durableContext) {
61+
protected BaseDurableOperation(
62+
OperationIdentifier operationIdentifier,
63+
DurableContextImpl durableContext,
64+
BaseDurableOperation parentOperation) {
6065
this.operationIdentifier = operationIdentifier;
66+
this.parentOperation = parentOperation;
6167
this.durableContext = durableContext;
6268
this.executionManager = durableContext.getExecutionManager();
6369

@@ -179,7 +185,9 @@ protected Operation waitForOperationCompletion() {
179185
// It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
180186
// completionFuture is completed by a user thread (a step or child context thread) when the execution here
181187
// is between `isOperationCompleted` and `thenRun`.
182-
synchronized (completionFuture) {
188+
// If this operation is a branch/iteration of a ConcurrencyOperation (map or parallel), the branches/iterations
189+
// must be completed sequentially to avoid race conditions.
190+
synchronized (parentOperation == null ? completionFuture : parentOperation) {
183191
if (!isOperationCompleted()) {
184192
// Operation not done yet
185193
logger.trace(
@@ -282,7 +290,7 @@ protected void markAlreadyCompleted() {
282290
private void markCompletionFutureCompleted() {
283291
// It's important that we synchronize access to the future, otherwise the processing could happen
284292
// on someone else's thread and cause a race condition.
285-
synchronized (completionFuture) {
293+
synchronized (parentOperation == null ? completionFuture : parentOperation) {
286294
// Completing the future here will also run any other completion stages that have been attached
287295
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
288296
// so they will definitely have a chance to reactivate before we finish completing and deactivating

sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ public class ChildContextOperation<T> extends SerializableDurableOperation<T> {
4646
private static final int LARGE_RESULT_THRESHOLD = 256 * 1024;
4747

4848
private final Function<DurableContext, T> function;
49-
private final ConcurrencyOperation<?> parentOperation;
5049
private final AtomicBoolean replayChildren = new AtomicBoolean(false);
5150
private T reconstructedResult;
5251

@@ -66,9 +65,8 @@ public ChildContextOperation(
6665
RunInChildContextConfig config,
6766
DurableContextImpl durableContext,
6867
ConcurrencyOperation<?> parentOperation) {
69-
super(operationIdentifier, resultTypeToken, config.serDes(), durableContext);
68+
super(operationIdentifier, resultTypeToken, config.serDes(), durableContext, parentOperation);
7069
this.function = function;
71-
this.parentOperation = parentOperation;
7270
}
7371

7472
/** Starts the operation. */

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected ConcurrencyOperation(
8282
this.toleratedFailureCount = toleratedFailureCount;
8383
this.operationIdGenerator = new OperationIdGenerator(getOperationId());
8484
this.rootContext = durableContext.createChildContext(getOperationId(), getName());
85-
this.consumerThreadListener = new AtomicReference<>(null);
85+
this.consumerThreadListener = new AtomicReference<>(new CompletableFuture<>());
8686
}
8787

8888
// ========== Template methods for subclasses ==========
@@ -138,15 +138,13 @@ protected <R> ChildContextOperation<R> enqueueItem(
138138
pendingQueue.add(childOp);
139139
logger.debug("Item enqueued {}", name);
140140
// notify the consumer thread a new item is available
141-
completeVacancyListenerIfSet();
141+
notifyConsumerThread();
142142
return childOp;
143143
}
144144

145-
private void completeVacancyListenerIfSet() {
145+
private void notifyConsumerThread() {
146146
synchronized (this) {
147-
if (consumerThreadListener.get() != null) {
148-
consumerThreadListener.get().complete(null);
149-
}
147+
consumerThreadListener.get().complete(null);
150148
}
151149
}
152150

@@ -159,22 +157,40 @@ protected void executeItems() {
159157

160158
Runnable consumer = () -> {
161159
while (true) {
160+
// Replace the listener future if it has already completed so that it can receive a notification of
161+
// new items while this thread is checking the completion condition and processing
162+
// the queued items below.
163+
synchronized (this) {
164+
if (consumerThreadListener.get() != null
165+
&& consumerThreadListener.get().isDone()) {
166+
consumerThreadListener.set(new CompletableFuture<>());
167+
}
168+
}
169+
170+
// Process completion condition. Quit the loop if the condition is met.
162171
if (isOperationCompleted()) {
163172
return;
164173
}
165174
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
166175
if (completionStatus != null) {
167-
handleComplete(completionStatus);
176+
handleSuccess(completionStatus);
168177
return;
169178
}
179+
180+
// process new items in the queue
170181
while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
171182
var next = pendingQueue.poll();
172183
runningChildren.add(next);
173184
logger.debug("Executing operation {}", next.getName());
174185
next.execute();
175186
}
187+
188+
// If consumerThreadListener was completed during the processing above, waitForChildCompletion will
189+
// return null immediately and the loop will repeat the steps above.
176190
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
177-
// child may be null if the consumer thread is woken up due to a new item being added
191+
192+
// child may be null if the consumer thread was woken up because new items were added or the completion
193+
// condition changed
178194
if (child != null) {
179195
if (runningChildren.contains(child)) {
180196
runningChildren.remove(child);
@@ -183,12 +199,6 @@ protected void executeItems() {
183199
throw new IllegalStateException("Unexpected completion: " + child);
184200
}
185201
}
186-
synchronized (this) {
187-
if (consumerThreadListener.get() != null
188-
&& consumerThreadListener.get().isDone()) {
189-
consumerThreadListener.set(null);
190-
}
191-
}
192202
}
193203
};
194204
// run consumer in the user thread pool, although it's not a real user thread
@@ -273,22 +283,15 @@ private ConcurrencyCompletionStatus canComplete(
273283
}
274284

275285
// All items finished — complete
286+
// This condition relies on isJoined, so the consumer will wake up and check this again when
287+
// isJoined is set to true.
276288
if (isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty()) {
277289
return ConcurrencyCompletionStatus.ALL_COMPLETED;
278290
}
279291

280292
return null;
281293
}
282294

283-
private void handleComplete(ConcurrencyCompletionStatus status) {
284-
synchronized (this) {
285-
if (isOperationCompleted()) {
286-
return;
287-
}
288-
handleSuccess(status);
289-
}
290-
}
291-
292295
/**
293296
* Blocks the calling thread until the concurrency operation reaches a terminal state. Validates item count, handles
294297
* zero-branch case, then delegates to {@code waitForOperationCompletion()} from BaseDurableOperation.
@@ -299,8 +302,10 @@ protected void join() {
299302
+ ") exceeds the number of registered items (" + branches.size() + ")");
300303
}
301304
isJoined.set(true);
302-
// notify the execution thread this concurrency operation is joined
303-
completeVacancyListenerIfSet();
305+
306+
// Notify the consumer thread that this concurrency operation is joined. The consumer thread needs to check the
307+
// completion condition again.
308+
notifyConsumerThread();
304309
waitForOperationCompletion();
305310
}
306311

sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,16 @@ protected SerializableDurableOperation(
5252
TypeToken<T> resultTypeToken,
5353
SerDes resultSerDes,
5454
DurableContextImpl durableContext) {
55-
super(operationIdentifier, durableContext);
55+
this(operationIdentifier, resultTypeToken, resultSerDes, durableContext, null);
56+
}
57+
58+
protected SerializableDurableOperation(
59+
OperationIdentifier operationIdentifier,
60+
TypeToken<T> resultTypeToken,
61+
SerDes resultSerDes,
62+
DurableContextImpl durableContext,
63+
BaseDurableOperation parentOperation) {
64+
super(operationIdentifier, durableContext, parentOperation);
5665
this.resultTypeToken = resultTypeToken;
5766
this.resultSerDes = resultSerDes;
5867
}

sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class WaitOperation extends BaseDurableOperation implements DurableFuture
2929

3030
public WaitOperation(
3131
OperationIdentifier operationIdentifier, Duration duration, DurableContextImpl durableContext) {
32-
super(operationIdentifier, durableContext);
32+
super(operationIdentifier, durableContext, null);
3333
this.duration = duration;
3434
}
3535

0 commit comments

Comments
 (0)