Skip to content

Commit c3316dd

Browse files
committed
fix race conditions in concurrency operation
1 parent af82eea commit c3316dd

17 files changed

Lines changed: 253 additions & 590 deletions

sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java

Lines changed: 0 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@
55
import com.amazonaws.services.lambda.runtime.Context;
66
import software.amazon.lambda.durable.DurableConfig;
77
import software.amazon.lambda.durable.execution.ExecutionManager;
8-
import software.amazon.lambda.durable.execution.SuspendExecutionException;
9-
import software.amazon.lambda.durable.execution.ThreadContext;
108
import software.amazon.lambda.durable.execution.ThreadType;
119

1210
public abstract class BaseContextImpl implements AutoCloseable, BaseContext {
@@ -36,40 +34,13 @@ protected BaseContextImpl(
3634
String contextId,
3735
String contextName,
3836
ThreadType threadType) {
39-
this(executionManager, durableConfig, lambdaContext, contextId, contextName, threadType, true);
40-
}
41-
42-
/**
43-
* Creates a new BaseContext instance.
44-
*
45-
* @param executionManager the execution manager for thread coordination and state management
46-
* @param durableConfig the durable execution configuration
47-
* @param lambdaContext the AWS Lambda runtime context
48-
* @param contextId the context ID, null for root context, set for child contexts
49-
* @param contextName the human-readable name for this context
50-
* @param threadType the type of thread this context runs on
51-
* @param setCurrentThreadContext whether to call setCurrentThreadContext on the execution manager
52-
*/
53-
protected BaseContextImpl(
54-
ExecutionManager executionManager,
55-
DurableConfig durableConfig,
56-
Context lambdaContext,
57-
String contextId,
58-
String contextName,
59-
ThreadType threadType,
60-
boolean setCurrentThreadContext) {
6137
this.executionManager = executionManager;
6238
this.durableConfig = durableConfig;
6339
this.lambdaContext = lambdaContext;
6440
this.contextId = contextId;
6541
this.contextName = contextName;
6642
this.isReplaying = executionManager.hasOperationsForContext(contextId);
6743
this.threadType = threadType;
68-
69-
if (setCurrentThreadContext) {
70-
// write the thread id and type to thread local
71-
executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType));
72-
}
7344
}
7445

7546
// =============== accessors ================
@@ -138,23 +109,4 @@ public boolean isReplaying() {
138109
public void setExecutionMode() {
139110
this.isReplaying = false;
140111
}
141-
142-
@Override
143-
public void close() {
144-
// this is called in the user thread, after the context's user code has completed
145-
if (getContextId() != null) {
146-
// if this is a child context or a step context, we need to
147-
// deregister the context's thread from the execution manager
148-
try {
149-
executionManager.deregisterActiveThread(getContextId());
150-
} catch (SuspendExecutionException e) {
151-
// Expected when this is the last active thread. Must catch here because:
152-
// 1/ This runs in a worker thread detached from handlerFuture
153-
// 2/ An uncaught exception would prevent stepAsync().get() from resuming
154-
// Suspension/Termination is already signaled via
155-
// suspendExecutionFuture/terminateExecutionFuture
156-
// before the throw.
157-
}
158-
}
159-
}
160112
}

sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,7 @@ private DurableContextImpl(
6767
Context lambdaContext,
6868
String contextId,
6969
String contextName) {
70-
this(executionManager, durableConfig, lambdaContext, contextId, contextName, true);
71-
}
72-
73-
private DurableContextImpl(
74-
ExecutionManager executionManager,
75-
DurableConfig durableConfig,
76-
Context lambdaContext,
77-
String contextId,
78-
String contextName,
79-
boolean setCurrentThreadContext) {
80-
super(
81-
executionManager,
82-
durableConfig,
83-
lambdaContext,
84-
contextId,
85-
contextName,
86-
ThreadType.CONTEXT,
87-
setCurrentThreadContext);
70+
super(executionManager, durableConfig, lambdaContext, contextId, contextName, ThreadType.CONTEXT);
8871
operationIdGenerator = new OperationIdGenerator(contextId);
8972
}
9073

@@ -115,22 +98,6 @@ public DurableContextImpl createChildContext(String childContextId, String child
11598
getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName);
11699
}
117100

118-
/**
119-
* Creates a child context without setting the current thread context.
120-
*
121-
* <p>Use this when the child context is being created on a thread that should not have its thread-local context
122-
* overwritten (e.g. when constructing the context ahead of running it on a separate thread).
123-
*
124-
* @param childContextId the child context's ID (the CONTEXT operation's operation ID)
125-
* @param childContextName the name of the child context
126-
* @return a new DurableContext for the child context
127-
*/
128-
public DurableContextImpl createChildContextWithoutSettingThreadContext(
129-
String childContextId, String childContextName) {
130-
return new DurableContextImpl(
131-
getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName, false);
132-
}
133-
134101
/**
135102
* Creates a step context for executing step operations.
136103
*
@@ -418,7 +385,6 @@ public void close() {
418385
if (logger != null) {
419386
logger.close();
420387
}
421-
super.close();
422388
}
423389

424390
/**

sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,5 @@ public void close() {
6666
if (logger != null) {
6767
logger.close();
6868
}
69-
super.close();
7069
}
7170
}

sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public static <I, O> DurableExecutionOutput execute(
5151
executionManager.registerActiveThread(null);
5252
var handlerFuture = CompletableFuture.supplyAsync(
5353
() -> {
54+
executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT));
5455
var userInput = extractUserInput(
5556
executionManager.getExecutionOperation(), config.getSerDes(), inputType);
5657
// use try-with-resources to clear logger properties

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.List;
77
import java.util.Objects;
88
import java.util.concurrent.CompletableFuture;
9+
import java.util.concurrent.atomic.AtomicReference;
910
import org.slf4j.Logger;
1011
import org.slf4j.LoggerFactory;
1112
import software.amazon.awssdk.services.lambda.model.Operation;
@@ -16,6 +17,7 @@
1617
import software.amazon.lambda.durable.exception.NonDeterministicExecutionException;
1718
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
1819
import software.amazon.lambda.durable.execution.ExecutionManager;
20+
import software.amazon.lambda.durable.execution.SuspendExecutionException;
1921
import software.amazon.lambda.durable.execution.ThreadContext;
2022
import software.amazon.lambda.durable.execution.ThreadType;
2123
import software.amazon.lambda.durable.model.OperationIdentifier;
@@ -46,6 +48,7 @@ public abstract class BaseDurableOperation {
4648
protected final ExecutionManager executionManager;
4749
protected final CompletableFuture<BaseDurableOperation> completionFuture;
4850
private final DurableContextImpl durableContext;
51+
private final AtomicReference<CompletableFuture<Void>> runningUserHandler = new AtomicReference<>(null);
4952

5053
/**
5154
* Constructs a new durable operation.
@@ -152,7 +155,7 @@ private void validateCurrentThreadType() {
152155
"Nested %s operation is not supported on %s from within a %s execution.",
153156
getType(), getName(), current);
154157
// terminate execution and throw the exception
155-
terminateExecutionWithIllegalDurableOperationException(message);
158+
throw terminateExecutionWithIllegalDurableOperationException(message);
156159
}
157160
}
158161

@@ -208,6 +211,50 @@ protected Operation waitForOperationCompletion() {
208211
return op;
209212
}
210213

214+
protected void runUserHandler(Runnable runnable, String contextId, ThreadType threadType) {
215+
Runnable wrapped = () -> {
216+
executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType));
217+
try {
218+
runnable.run();
219+
} finally {
220+
if (contextId != null) {
221+
try {
222+
// if this is a child context or a step context, we need to
223+
// deregister the context's thread from the execution manager
224+
executionManager.deregisterActiveThread(contextId);
225+
} catch (SuspendExecutionException e) {
226+
// Expected when this is the last active thread. Must catch here because:
227+
// 1/ This runs in a worker thread detached from handlerFuture
228+
// 2/ An uncaught exception would prevent stepAsync().get() from resuming
229+
// Suspension/Termination is already signaled via
230+
// suspendExecutionFuture/terminateExecutionFuture
231+
// before the throw.
232+
}
233+
}
234+
}
235+
};
236+
237+
// runUserHandler is used to ensure that only one user handler is running at a time
238+
if (runningUserHandler.get() != null) {
239+
throw new IllegalStateException("User handler already running");
240+
}
241+
242+
// Thread registration is intentionally split across two threads:
243+
// 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the
244+
// parent can deregister and trigger suspension (race prevention).
245+
// 2. setCurrentThreadContext on the CHILD thread — sets the ThreadLocal so operations inside
246+
// the child context know which context they belong to.
247+
// registerActiveThread is idempotent (no-op if already registered).
248+
registerActiveThread(contextId);
249+
250+
if (!runningUserHandler.compareAndSet(
251+
null,
252+
CompletableFuture.runAsync(
253+
wrapped, getContext().getDurableConfig().getExecutorService()))) {
254+
throw new IllegalStateException("User handler already running");
255+
}
256+
}
257+
211258
/**
212259
* Receives operation updates from ExecutionManager. Completes the internal future when the operation reaches a
213260
* terminal status, unblocking any threads waiting on this operation.
@@ -317,21 +364,21 @@ protected void validateReplay(Operation checkpointed) {
317364
}
318365

319366
if (!checkpointed.type().equals(getType())) {
320-
terminateExecution(new NonDeterministicExecutionException(String.format(
367+
throw terminateExecution(new NonDeterministicExecutionException(String.format(
321368
"Operation type mismatch for \"%s\". Expected %s, got %s",
322369
getOperationId(), checkpointed.type(), getType())));
323370
}
324371

325372
if (!Objects.equals(checkpointed.name(), getName())) {
326-
terminateExecution(new NonDeterministicExecutionException(String.format(
373+
throw terminateExecution(new NonDeterministicExecutionException(String.format(
327374
"Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"",
328375
getOperationId(), checkpointed.name(), getName())));
329376
}
330377

331378
if ((getSubType() == null && checkpointed.subType() != null)
332379
|| getSubType() != null
333380
&& !Objects.equals(checkpointed.subType(), getSubType().getValue())) {
334-
terminateExecution(new NonDeterministicExecutionException(String.format(
381+
throw terminateExecution(new NonDeterministicExecutionException(String.format(
335382
"Operation subType mismatch for \"%s\". Expected \"%s\", got \"%s\"",
336383
getOperationId(), checkpointed.subType(), getSubType())));
337384
}

sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ protected void replay(Operation existing) {
6565
// Still waiting - continue to polling
6666
}
6767
default ->
68-
terminateExecutionWithIllegalDurableOperationException(
68+
throw terminateExecutionWithIllegalDurableOperationException(
6969
"Unexpected callback status: " + existing.status());
7070
}
7171
pollForOperationUpdates();

sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import static software.amazon.lambda.durable.execution.ExecutionManager.isTerminalStatus;
66

77
import java.nio.charset.StandardCharsets;
8-
import java.util.concurrent.CompletableFuture;
9-
import java.util.concurrent.ExecutorService;
108
import java.util.function.Function;
119
import software.amazon.awssdk.services.lambda.model.ContextOptions;
1210
import software.amazon.awssdk.services.lambda.model.ErrorObject;
@@ -28,6 +26,7 @@
2826
import software.amazon.lambda.durable.exception.StepInterruptedException;
2927
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
3028
import software.amazon.lambda.durable.execution.SuspendExecutionException;
29+
import software.amazon.lambda.durable.execution.ThreadType;
3130
import software.amazon.lambda.durable.model.OperationIdentifier;
3231
import software.amazon.lambda.durable.util.ExceptionHelper;
3332

@@ -46,7 +45,6 @@ public class ChildContextOperation<T> extends SerializableDurableOperation<T> {
4645
private static final int LARGE_RESULT_THRESHOLD = 256 * 1024;
4746

4847
private final Function<DurableContext, T> function;
49-
private final ExecutorService userExecutor;
5048
private final ConcurrencyOperation<?> parentOperation;
5149
private boolean replayChildContext;
5250
private T reconstructedResult;
@@ -69,7 +67,6 @@ public ChildContextOperation(
6967
ConcurrencyOperation<?> parentOperation) {
7068
super(operationIdentifier, resultTypeToken, config.serDes(), durableContext);
7169
this.function = function;
72-
this.userExecutor = getContext().getDurableConfig().getExecutorService();
7370
this.parentOperation = parentOperation;
7471
}
7572

@@ -116,14 +113,6 @@ private void executeChildContext() {
116113
// third level child context "hash(hash(hash(1)-2)-1)".
117114
var contextId = getOperationId();
118115

119-
// Thread registration is intentionally split across two threads:
120-
// 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the
121-
// parent can deregister and trigger suspension (race prevention).
122-
// 2. setCurrentThreadContext on the CHILD thread — sets the ThreadLocal so operations inside
123-
// the child context know which context they belong to.
124-
// registerActiveThread is idempotent (no-op if already registered).
125-
registerActiveThread(contextId);
126-
127116
Runnable userHandler = () -> {
128117
// use a try-with-resources to
129118
// - add thread id/type to thread local when the step starts
@@ -144,7 +133,7 @@ private void executeChildContext() {
144133
};
145134

146135
// Execute user provided child context code in user-configured executor
147-
CompletableFuture.runAsync(userHandler, userExecutor);
136+
runUserHandler(userHandler, contextId, ThreadType.CONTEXT);
148137
}
149138

150139
private void handleChildContextSuccess(T result) {
@@ -192,7 +181,7 @@ private void handleChildContextFailure(Throwable exception) {
192181
}
193182
if (exception instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) {
194183
// terminate the execution and throw the exception if it's not recoverable
195-
terminateExecution(unrecoverableDurableExecutionException);
184+
throw terminateExecution(unrecoverableDurableExecutionException);
196185
}
197186

198187
// Skip checkpointing if parent ConcurrencyOperation has already completed —

0 commit comments

Comments
 (0)