Skip to content

Commit bb37755

Browse files
committed
rename thread context methods
1 parent 2f46065 commit bb37755

10 files changed

Lines changed: 69 additions & 62 deletions

File tree

sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package com.amazonaws.lambda.durable;
44

55
import com.amazonaws.lambda.durable.execution.ExecutionManager;
6+
import com.amazonaws.lambda.durable.execution.ThreadContext;
67
import com.amazonaws.lambda.durable.execution.ThreadType;
78
import com.amazonaws.lambda.durable.logging.DurableLogger;
89
import com.amazonaws.lambda.durable.operation.CallbackOperation;
@@ -43,8 +44,8 @@ public class DurableContext {
4344
durableConfig.getLoggerConfig().suppressReplayLogs());
4445

4546
// Register root context thread as active
46-
executionManager.registerActiveThread(contextId, ThreadType.CONTEXT);
47-
executionManager.setCurrentContext(contextId, ThreadType.CONTEXT);
47+
executionManager.registerActiveThread(contextId);
48+
executionManager.setCurrentThreadContext(new ThreadContext(contextId, ThreadType.CONTEXT));
4849
}
4950

5051
DurableContext(ExecutionManager executionManager, DurableConfig config, Context lambdaContext) {

sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import java.time.Duration;
99
import java.util.Collections;
1010
import java.util.HashMap;
11+
import java.util.HashSet;
1112
import java.util.List;
1213
import java.util.Map;
14+
import java.util.Set;
1315
import java.util.concurrent.CompletableFuture;
1416
import java.util.concurrent.atomic.AtomicReference;
1517
import java.util.stream.Collectors;
@@ -51,8 +53,8 @@ public class ExecutionManager {
5153
// ===== Thread Coordination =====
5254
private final Map<String, BaseDurableOperation<?>> registeredOperations =
5355
Collections.synchronizedMap(new HashMap<>());
54-
private final Map<String, ThreadType> activeThreads = Collections.synchronizedMap(new HashMap<>());
55-
private static final ThreadLocal<OperationContext> currentContext = new ThreadLocal<>();
56+
private final Set<String> activeThreads = Collections.synchronizedSet(new HashSet<>());
57+
private static final ThreadLocal<ThreadContext> currentThreadContext = new ThreadLocal<>();
5658
private final CompletableFuture<Void> executionExceptionFuture = new CompletableFuture<>();
5759

5860
// ===== Checkpoint Batching =====
@@ -130,53 +132,48 @@ public Operation getExecutionOperation() {
130132
}
131133

132134
// ===== Thread Coordination =====
135+
/** Sets the current thread's ThreadContext (threadId and threadType). Called when a user thread is started. */
136+
public void setCurrentThreadContext(ThreadContext threadContext) {
137+
currentThreadContext.set(threadContext);
138+
}
139+
140+
/** Returns the current thread's ThreadContext (threadId and threadType), or null if not set. */
141+
public ThreadContext getCurrentThreadContext() {
142+
return currentThreadContext.get();
143+
}
144+
133145
/**
134-
* Registers a thread as active without setting the thread local OperationContext. Use this when registration must
135-
* happen on a different thread than execution. Call setCurrentContext() on the execution thread to set the local
136-
* OperationContext.
146+
* Registers a thread as active.
137147
*
138-
* @see OperationContext
148+
* @see ThreadContext
139149
*/
140-
public void registerActiveThread(String threadId, ThreadType threadType) {
141-
if (activeThreads.containsKey(threadId)) {
142-
logger.trace("Thread '{}' ({}) already registered as active", threadId, threadType);
150+
public void registerActiveThread(String threadId) {
151+
if (activeThreads.contains(threadId)) {
152+
logger.trace("Thread '{}' already registered as active", threadId);
143153
return;
144154
}
145-
activeThreads.put(threadId, threadType);
146-
logger.trace(
147-
"Registered thread '{}' ({}) as active (no context). Active threads: {}",
148-
threadId,
149-
threadType,
150-
activeThreads.size());
155+
activeThreads.add(threadId);
156+
logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
151157
}
152158

153159
/**
154-
* Sets the current thread's context. Use after registerActiveThreadWithoutContext() when the execution thread is
155-
* different from the registration thread.
160+
* Mark a thread as inactive. If no threads remain, suspends the execution.
161+
*
162+
* @param threadId the thread ID to deregister
156163
*/
157-
public void setCurrentContext(String contextId, ThreadType threadType) {
158-
currentContext.set(new OperationContext(contextId, threadType));
159-
}
160-
161-
/** Returns the current thread's context, or null if not set. */
162-
public OperationContext getCurrentContext() {
163-
return currentContext.get();
164-
}
165-
166164
public void deregisterActiveThread(String threadId) {
167165
// Skip if already suspended
168166
if (executionExceptionFuture.isDone()) {
169167
return;
170168
}
171169

172-
if (!activeThreads.containsKey(threadId)) {
170+
boolean removed = activeThreads.remove(threadId);
171+
if (removed) {
172+
logger.trace("Deregistered thread '{}' Active threads: {}", threadId, activeThreads.size());
173+
} else {
173174
logger.warn("Thread '{}' not active, cannot deregister", threadId);
174-
return;
175175
}
176176

177-
ThreadType type = activeThreads.remove(threadId);
178-
logger.trace("Deregistered thread '{}' ({}). Active threads: {}", threadId, type, activeThreads.size());
179-
180177
if (activeThreads.isEmpty()) {
181178
logger.info("No active threads remaining - suspending execution");
182179
suspendExecution();

sdk/src/main/java/com/amazonaws/lambda/durable/execution/OperationContext.java renamed to sdk/src/main/java/com/amazonaws/lambda/durable/execution/ThreadContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
package com.amazonaws.lambda.durable.execution;
44

55
/** Holds the current thread's execution context. */
6-
public record OperationContext(String contextId, ThreadType threadType) {}
6+
public record ThreadContext(String threadId, ThreadType threadType) {}

sdk/src/main/java/com/amazonaws/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import com.amazonaws.lambda.durable.exception.SerDesException;
1010
import com.amazonaws.lambda.durable.exception.UnrecoverableDurableExecutionException;
1111
import com.amazonaws.lambda.durable.execution.ExecutionManager;
12-
import com.amazonaws.lambda.durable.execution.ThreadType;
12+
import com.amazonaws.lambda.durable.execution.ThreadContext;
1313
import com.amazonaws.lambda.durable.serde.SerDes;
1414
import com.amazonaws.lambda.durable.util.ExceptionHelper;
1515
import java.time.Duration;
@@ -107,29 +107,32 @@ protected boolean isOperationCompleted() {
107107

108108
/** Waits for the operation to complete and suspends the execution if no active thread is running */
109109
protected Operation waitForOperationCompletion() {
110-
111-
var context = executionManager.getCurrentContext();
110+
var threadContext = getCurrentThreadContext();
112111

113112
// It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
114113
// completionFuture is completed by a user thread (a step or child context thread) when the execution here
115114
// is between `isOperationCompleted` and `thenRun`.
116115
synchronized (completionFuture) {
117116
if (!isOperationCompleted()) {
118117
// Operation not done yet
119-
logger.debug("get() on {} attempting to deregister context: {}", getType(), context.contextId());
118+
logger.trace(
119+
"deregistering thread {} when waiting for operation {} ({}) to complete ({})",
120+
threadContext.threadId(),
121+
getOperation(),
122+
getType(),
123+
completionFuture);
120124

121125
// Add a completion stage to completionFuture so that when the completionFuture is completed,
122126
// it will register the current Context thread synchronously to make sure it is always registered
123127
// strictly before the execution thread (Step or child context) is deregistered.
124-
completionFuture.thenRun(() -> registerActiveThread(context.contextId(), context.threadType()));
128+
completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId()));
125129

126130
// Deregister the current thread to allow suspension
127-
deregisterActiveThread(context.contextId());
131+
deregisterActiveThread(threadContext.threadId());
128132
}
129133
}
130134

131135
// Block until operation completes. No-op if the future is already completed.
132-
logger.trace("Waiting for operation to finish {} ({})", getOperationId(), completionFuture);
133136
completionFuture.join();
134137

135138
// Get result based on status
@@ -186,12 +189,16 @@ protected void deregisterActiveThread(String threadId) {
186189
executionManager.deregisterActiveThread(threadId);
187190
}
188191

189-
protected void registerActiveThread(String threadId, ThreadType threadType) {
190-
executionManager.registerActiveThread(threadId, threadType);
192+
protected void registerActiveThread(String threadId) {
193+
executionManager.registerActiveThread(threadId);
194+
}
195+
196+
protected ThreadContext getCurrentThreadContext() {
197+
return executionManager.getCurrentThreadContext();
191198
}
192199

193-
protected void setCurrentContext(String stepThreadId, ThreadType step) {
194-
executionManager.setCurrentContext(stepThreadId, step);
200+
protected void setCurrentThreadContext(ThreadContext threadContext) {
201+
executionManager.setCurrentThreadContext(threadContext);
195202
}
196203

197204
// polling and checkpointing

sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.amazonaws.lambda.durable.exception.UnrecoverableDurableExecutionException;
1313
import com.amazonaws.lambda.durable.execution.ExecutionManager;
1414
import com.amazonaws.lambda.durable.execution.SuspendExecutionException;
15+
import com.amazonaws.lambda.durable.execution.ThreadContext;
1516
import com.amazonaws.lambda.durable.execution.ThreadType;
1617
import com.amazonaws.lambda.durable.logging.DurableLogger;
1718
import com.amazonaws.lambda.durable.util.ExceptionHelper;
@@ -102,13 +103,13 @@ private void executeStepLogic(int attempt) {
102103
var stepThreadId = getOperationId() + "-step";
103104

104105
// Register step thread as active BEFORE executor runs (prevents suspension when handler deregisters)
105-
// thread local OperationContext is set inside the executor since that's where the step actually runs
106-
registerActiveThread(stepThreadId, ThreadType.STEP);
106+
// thread local ThreadContext is set inside the executor since that's where the step actually runs
107+
registerActiveThread(stepThreadId);
107108

108109
// Execute user code in customer-configured executor
109110
userExecutor.execute(() -> {
110-
// Set thread local OperationContext on the executor thread
111-
setCurrentContext(stepThreadId, ThreadType.STEP);
111+
// Set thread local ThreadContext on the executor thread
112+
setCurrentThreadContext(new ThreadContext(stepThreadId, ThreadType.STEP));
112113
// Set operation context for logging in this thread
113114
durableLogger.setOperationContext(getOperationId(), getName(), attempt);
114115
try {

sdk/src/test/java/com/amazonaws/lambda/durable/operation/BaseDurableOperationTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import com.amazonaws.lambda.durable.exception.NonDeterministicExecutionException;
2121
import com.amazonaws.lambda.durable.exception.SerDesException;
2222
import com.amazonaws.lambda.durable.execution.ExecutionManager;
23-
import com.amazonaws.lambda.durable.execution.OperationContext;
23+
import com.amazonaws.lambda.durable.execution.ThreadContext;
2424
import com.amazonaws.lambda.durable.execution.ThreadType;
2525
import com.amazonaws.lambda.durable.serde.JacksonSerDes;
2626
import com.amazonaws.lambda.durable.serde.SerDes;
@@ -54,7 +54,7 @@ class BaseDurableOperationTest {
5454
@BeforeEach
5555
void setUp() {
5656
executionManager = mock(ExecutionManager.class);
57-
when(executionManager.getCurrentContext()).thenReturn(new OperationContext(CONTEXT_ID, ThreadType.CONTEXT));
57+
when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT));
5858
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(OPERATION);
5959
}
6060

@@ -128,7 +128,7 @@ public String get() {
128128
Operation.builder().status(OperationStatus.SUCCEEDED).build());
129129
assertEquals(RESULT, future.get());
130130
verify(executionManager).deregisterActiveThread(CONTEXT_ID);
131-
verify(executionManager).registerActiveThread(CONTEXT_ID, ThreadType.CONTEXT);
131+
verify(executionManager).registerActiveThread(CONTEXT_ID);
132132
}
133133
}
134134

@@ -151,7 +151,7 @@ public String get() {
151151

152152
op.execute();
153153
verify(executionManager, never()).deregisterActiveThread(CONTEXT_ID);
154-
verify(executionManager, never()).registerActiveThread(CONTEXT_ID, ThreadType.CONTEXT);
154+
verify(executionManager, never()).registerActiveThread(CONTEXT_ID);
155155
}
156156

157157
@Test

sdk/src/test/java/com/amazonaws/lambda/durable/operation/CallbackOperationTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.amazonaws.lambda.durable.exception.CallbackTimeoutException;
1313
import com.amazonaws.lambda.durable.exception.SerDesException;
1414
import com.amazonaws.lambda.durable.execution.ExecutionManager;
15+
import com.amazonaws.lambda.durable.execution.ThreadContext;
1516
import com.amazonaws.lambda.durable.execution.ThreadType;
1617
import com.amazonaws.lambda.durable.serde.JacksonSerDes;
1718
import com.amazonaws.lambda.durable.serde.SerDes;
@@ -71,7 +72,7 @@ private ExecutionManager createExecutionManager(List<Operation> initialOperation
7172
"test-token",
7273
initialState,
7374
DurableConfig.builder().withDurableExecutionClient(client).build());
74-
executionManager.setCurrentContext("Root", ThreadType.CONTEXT);
75+
executionManager.setCurrentThreadContext(new ThreadContext("Root", ThreadType.CONTEXT));
7576
return executionManager;
7677
}
7778

sdk/src/test/java/com/amazonaws/lambda/durable/operation/InvokeOperationTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import com.amazonaws.lambda.durable.exception.InvokeStoppedException;
1515
import com.amazonaws.lambda.durable.exception.InvokeTimedOutException;
1616
import com.amazonaws.lambda.durable.execution.ExecutionManager;
17-
import com.amazonaws.lambda.durable.execution.OperationContext;
17+
import com.amazonaws.lambda.durable.execution.ThreadContext;
1818
import com.amazonaws.lambda.durable.execution.ThreadType;
1919
import com.amazonaws.lambda.durable.serde.JacksonSerDes;
2020
import org.junit.jupiter.api.BeforeEach;
@@ -32,7 +32,7 @@ class InvokeOperationTest {
3232
@BeforeEach
3333
void setUp() {
3434
executionManager = mock(ExecutionManager.class);
35-
when(executionManager.getCurrentContext()).thenReturn(new OperationContext("root", ThreadType.CONTEXT));
35+
when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("root", ThreadType.CONTEXT));
3636
}
3737

3838
@Test

sdk/src/test/java/com/amazonaws/lambda/durable/operation/StepOperationTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import com.amazonaws.lambda.durable.exception.StepFailedException;
1212
import com.amazonaws.lambda.durable.exception.StepInterruptedException;
1313
import com.amazonaws.lambda.durable.execution.ExecutionManager;
14-
import com.amazonaws.lambda.durable.execution.OperationContext;
14+
import com.amazonaws.lambda.durable.execution.ThreadContext;
1515
import com.amazonaws.lambda.durable.execution.ThreadType;
1616
import com.amazonaws.lambda.durable.logging.DurableLogger;
1717
import com.amazonaws.lambda.durable.serde.JacksonSerDes;
@@ -31,7 +31,7 @@ class StepOperationTest {
3131

3232
private ExecutionManager createMockExecutionManager() {
3333
var executionManager = mock(ExecutionManager.class);
34-
when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT));
34+
when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("handler", ThreadType.CONTEXT));
3535
return executionManager;
3636
}
3737

@@ -67,7 +67,7 @@ void getDoesNotThrowWhenCalledFromHandlerContext() {
6767
.stepDetails(StepDetails.builder().result("\"cached-result\"").build())
6868
.build();
6969
var executionManager = mock(ExecutionManager.class);
70-
when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT));
70+
when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("handler", ThreadType.CONTEXT));
7171
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(op);
7272

7373
var operation = new StepOperation<>(

sdk/src/test/java/com/amazonaws/lambda/durable/operation/WaitOperationTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import static org.mockito.Mockito.when;
1111

1212
import com.amazonaws.lambda.durable.execution.ExecutionManager;
13-
import com.amazonaws.lambda.durable.execution.OperationContext;
13+
import com.amazonaws.lambda.durable.execution.ThreadContext;
1414
import com.amazonaws.lambda.durable.execution.ThreadType;
1515
import java.time.Duration;
1616
import org.junit.jupiter.api.BeforeEach;
@@ -82,7 +82,7 @@ void getDoesNotThrowWhenCalledFromHandlerContext() {
8282
.status(OperationStatus.SUCCEEDED)
8383
.waitDetails(WaitDetails.builder().build())
8484
.build();
85-
when(executionManager.getCurrentContext()).thenReturn(new OperationContext(CONTEXT_ID, ThreadType.CONTEXT));
85+
when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT));
8686
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(op);
8787

8888
var operation = new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofSeconds(10), executionManager);
@@ -99,7 +99,7 @@ void getSucceededWhenStarted() {
9999
.name(OPERATION_NAME)
100100
.status(OperationStatus.SUCCEEDED)
101101
.build();
102-
when(executionManager.getCurrentContext()).thenReturn(new OperationContext(CONTEXT_ID, ThreadType.CONTEXT));
102+
when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT));
103103
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(op);
104104

105105
var operation = new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofSeconds(10), executionManager);

0 commit comments

Comments
 (0)