Skip to content

Commit fe12ac7

Browse files
authored
Merge pull request #9 from aws/phipag/nested-steps
feat(nested-steps): Stop hardcoding Root context thread
2 parents 6e07d68 + 8035c8b commit fe12ac7

9 files changed

Lines changed: 225 additions & 38 deletions

File tree

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable;
4+
5+
import static org.junit.jupiter.api.Assertions.*;
6+
7+
import com.amazonaws.lambda.durable.model.ExecutionStatus;
8+
import com.amazonaws.lambda.durable.testing.LocalDurableTestRunner;
9+
import org.junit.jupiter.api.Test;
10+
11+
/** Tests that nested step calling is properly rejected. */
12+
class NestedStepIntegrationTest {
13+
14+
@Test
15+
void nestedStepCallingThrowsIllegalStateException() {
16+
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
17+
// outer-step's supplier calls context.step() which internally calls stepAsync().get()
18+
// The get() is called from the outer step's thread (named "1-step"), triggering the check
19+
var future = context.stepAsync("outer-step", String.class, () -> {
20+
return context.step("inner-step", String.class, () -> "inner-result");
21+
});
22+
return future.get();
23+
});
24+
25+
var result = runner.run("test");
26+
27+
assertEquals(ExecutionStatus.FAILED, result.getStatus());
28+
var errorMessage = result.getError().get().errorMessage();
29+
assertTrue(
30+
errorMessage.contains("Nested step calling is not supported"),
31+
"Expected error about nested step calling, got: " + errorMessage);
32+
}
33+
34+
@Test
35+
void awaitingAsyncStepInsideSyncStepThrowsIllegalStateException() {
36+
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
37+
// Start async step from handler thread
38+
var asyncFuture = context.stepAsync("async-step", String.class, () -> "async-result");
39+
40+
// Sync step tries to await the async step's result inside its supplier
41+
return context.step("sync-step", String.class, () -> {
42+
// This get() is called from sync-step's thread ("2-step"), which is not allowed
43+
return "combined: " + asyncFuture.get();
44+
});
45+
});
46+
47+
var result = runner.run("test");
48+
49+
assertEquals(ExecutionStatus.FAILED, result.getStatus());
50+
var errorMessage = result.getError().get().errorMessage();
51+
assertTrue(
52+
errorMessage.contains("Nested step calling is not supported"),
53+
"Expected error about nested step calling, got: " + errorMessage);
54+
}
55+
}

sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ public class DurableContext {
2727
private final AtomicInteger operationCounter;
2828
private final DurableLogger logger;
2929

30-
DurableContext(ExecutionManager executionManager, SerDes serDes, Context lambdaContext, LoggerConfig loggerConfig) {
30+
DurableContext(
31+
ExecutionManager executionManager,
32+
SerDes serDes,
33+
Context lambdaContext,
34+
LoggerConfig loggerConfig,
35+
String contextId) {
3136
this.executionManager = executionManager;
3237
this.serDes = serDes;
3338
this.lambdaContext = lambdaContext;
@@ -41,9 +46,11 @@ public class DurableContext {
4146
loggerConfig.suppressReplayLogs());
4247

4348
// Register this context's thread (identified by contextId, "Root" by default) as active
44-
// TODO: Once we implement child contexts, the threadId needs to be the ID of
45-
// the child context
46-
executionManager.registerActiveThread("Root", ThreadType.CONTEXT);
49+
executionManager.registerActiveThreadWithContext(contextId, ThreadType.CONTEXT);
50+
}
51+
52+
DurableContext(ExecutionManager executionManager, SerDes serDes, Context lambdaContext, LoggerConfig loggerConfig) {
53+
this(executionManager, serDes, lambdaContext, loggerConfig, "Root");
4754
}
4855

4956
public <T> T step(String name, Class<T> resultType, Supplier<T> func) {

sdk/src/main/java/com/amazonaws/lambda/durable/DurableExecutor.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,15 @@ public static <I, O> DurableExecutionOutput execute(
7070
var serDes = config.getSerDes();
7171
var userInput = extractUserInput(executionOp, serDes, inputType);
7272

73-
// Create context
74-
var context = new DurableContext(executionManager, serDes, lambdaContext, config.getLoggerConfig());
75-
7673
try {
77-
var handlerFuture = CompletableFuture.supplyAsync(() -> handler.apply(userInput, context), executor);
74+
var handlerFuture = CompletableFuture.supplyAsync(
75+
() -> {
76+
// Create context on the executor thread so the thread-local OperationContext is bound to the thread running the handler
77+
var context =
78+
new DurableContext(executionManager, serDes, lambdaContext, config.getLoggerConfig());
79+
return handler.apply(userInput, context);
80+
},
81+
executor);
7882

7983
// Get suspend future from ExecutionManager. If this future completes, it
8084
// indicates

sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public class ExecutionManager {
5050

5151
// ===== Thread Coordination =====
5252
private final Map<String, ThreadType> activeThreads = Collections.synchronizedMap(new HashMap<>());
53+
private static final ThreadLocal<OperationContext> currentContext = new ThreadLocal<>();
5354
private final Map<String, Phaser> openPhasers = Collections.synchronizedMap(new HashMap<>());
5455
private final CompletableFuture<Void> suspendExecutionFuture = new CompletableFuture<>();
5556

@@ -164,20 +165,51 @@ public Operation getExecutionOperation() {
164165

165166
// ===== Thread Coordination =====
166167

167-
public void registerActiveThread(String threadId, ThreadType threadType) {
168+
public void registerActiveThreadWithContext(String threadId, ThreadType threadType) {
168169
if (activeThreads.containsKey(threadId)) {
169170
logger.trace("Thread '{}' ({}) already registered as active", threadId, threadType);
170171
return;
171172
}
173+
activeThreads.put(threadId, threadType);
174+
currentContext.set(new OperationContext(threadId, threadType));
175+
logger.trace(
176+
"Registered thread '{}' ({}) as active. Active threads: {}",
177+
threadId,
178+
threadType,
179+
activeThreads.size());
180+
}
172181

173-
synchronized (this) {
174-
activeThreads.put(threadId, threadType);
175-
logger.trace(
176-
"Registered thread '{}' ({}) as active. Active threads: {}",
177-
threadId,
178-
threadType,
179-
activeThreads.size());
182+
/**
183+
* Registers a thread as active without setting the thread local OperationContext. Use this when registration must
184+
* happen on a different thread than execution. Call setCurrentContext() on the execution thread to set the local
185+
* OperationContext.
186+
*
187+
* @see OperationContext
188+
*/
189+
public void registerActiveThread(String threadId, ThreadType threadType) {
190+
if (activeThreads.containsKey(threadId)) {
191+
logger.trace("Thread '{}' ({}) already registered as active", threadId, threadType);
192+
return;
180193
}
194+
activeThreads.put(threadId, threadType);
195+
logger.trace(
196+
"Registered thread '{}' ({}) as active (no context). Active threads: {}",
197+
threadId,
198+
threadType,
199+
activeThreads.size());
200+
}
201+
202+
/**
203+
* Sets the current thread's context. Use after registerActiveThread() when the execution thread is
204+
* different from the registration thread.
205+
*/
206+
public void setCurrentContext(String contextId, ThreadType threadType) {
207+
currentContext.set(new OperationContext(contextId, threadType));
208+
}
209+
210+
/** Returns the current thread's context, or null if not set. */
211+
public OperationContext getCurrentContext() {
212+
return currentContext.get();
181213
}
182214

183215
public void deregisterActiveThread(String threadId) {
@@ -191,15 +223,14 @@ public void deregisterActiveThread(String threadId) {
191223
return;
192224
}
193225

194-
synchronized (this) {
195-
ThreadType type = activeThreads.remove(threadId);
196-
logger.trace("Deregistered thread '{}' ({}). Active threads: {}", threadId, type, activeThreads.size());
226+
ThreadType type = activeThreads.remove(threadId);
227+
currentContext.remove();
228+
logger.trace("Deregistered thread '{}' ({}). Active threads: {}", threadId, type, activeThreads.size());
197229

198-
if (activeThreads.isEmpty()) {
199-
logger.info("No active threads remaining - suspending execution");
200-
suspendExecutionFuture.complete(null);
201-
throw new SuspendExecutionException();
202-
}
230+
if (activeThreads.isEmpty()) {
231+
logger.info("No active threads remaining - suspending execution");
232+
suspendExecutionFuture.complete(null);
233+
throw new SuspendExecutionException();
203234
}
204235
}
205236

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable.execution;
4+
5+
/** Holds the current thread's execution context. */
6+
public record OperationContext(String contextId, ThreadType threadType) {}

sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,17 @@ public void execute() {
163163
}
164164

165165
private void executeStepLogic(int attempt) {
166-
// Register step thread as active
166+
// TODO: Modify this logic when child contexts are introduced such that the child context id is in this key
167167
var stepThreadId = operationId + "-step";
168+
169+
// Register step thread as active BEFORE executor runs (prevents suspension when handler deregisters)
170+
// thread local OperationContext is set inside the executor since that's where the step actually runs
168171
executionManager.registerActiveThread(stepThreadId, ThreadType.STEP);
169172

170173
// Execute in managed executor
171174
executionManager.getManagedExecutor().execute(() -> {
175+
// Set thread local OperationContext on the executor thread
176+
executionManager.setCurrentContext(stepThreadId, ThreadType.STEP);
172177
// Set operation context for logging in this thread
173178
durableLogger.setOperationContext(operationId, name, attempt);
174179
try {
@@ -290,6 +295,15 @@ private void handleStepFailure(Throwable error, int attempt) {
290295

291296
@Override
292297
public T get() {
298+
// Get current context from ThreadLocal
299+
var currentContext = executionManager.getCurrentContext();
300+
301+
// Nested steps are not supported
302+
if (currentContext.threadType() == ThreadType.STEP) {
303+
throw new IllegalStateException("Nested step calling is not supported. Cannot call get() on step '" + name
304+
+ "' from within another step's execution.");
305+
}
306+
293307
// If we are in a replay where the operation is already complete (SUCCEEDED /
294308
// FAILED), the Phaser will be
295309
// advanced in .execute() already and we don't block but return the result
@@ -298,17 +312,16 @@ public T get() {
298312
// Operation not done yet
299313
phaser.register();
300314

301-
// Deregister current thread - allows suspension
302-
// TODO: The threadId here should be the (potential childContext) thread id that
303-
// is calling .get()
304-
executionManager.deregisterActiveThread("Root");
315+
// Deregister current context - allows suspension
316+
logger.debug("StepOperation.get() attempting to deregister context: {}", currentContext.contextId());
317+
executionManager.deregisterActiveThread(currentContext.contextId());
305318

306319
// Block until operation completes
307320
logger.trace("Waiting for operation to finish {} (Phaser: {})", operationId, phaser);
308321
phaser.arriveAndAwaitAdvance(); // Wait for phase 0
309322

310-
// Reactivate current thread
311-
executionManager.registerActiveThread("Root", ThreadType.CONTEXT);
323+
// Reactivate current context
324+
executionManager.registerActiveThreadWithContext(currentContext.contextId(), currentContext.threadType());
312325

313326
// Complete phase 1
314327
phaser.arriveAndDeregister();

sdk/src/main/java/com/amazonaws/lambda/durable/operation/WaitOperation.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import com.amazonaws.lambda.durable.execution.ExecutionManager;
66
import com.amazonaws.lambda.durable.execution.ExecutionPhase;
7-
import com.amazonaws.lambda.durable.execution.ThreadType;
87
import java.time.Duration;
98
import java.time.Instant;
109
import java.util.concurrent.Phaser;
@@ -97,16 +96,19 @@ public Void get() {
9796
if (phaser.getPhase() == ExecutionPhase.RUNNING.getValue()) {
9897
phaser.register();
9998

99+
// Get current context from ThreadLocal
100+
var currentContext = executionManager.getCurrentContext();
101+
100102
// Deregister current thread - THIS is where suspension can happen!
101103
// If no other threads are active, this will throw SuspendExecutionException
102-
executionManager.deregisterActiveThread("Root");
104+
executionManager.deregisterActiveThread(currentContext.contextId());
103105

104106
// Complete the wait phaser immediately (we don't actually wait in Lambda)
105107
// The backend handles the wait duration
106108
phaser.arriveAndAwaitAdvance(); // Phase 0 -> 1
107109

108110
// Reactivate current thread
109-
executionManager.registerActiveThread("Root", ThreadType.CONTEXT);
111+
executionManager.registerActiveThreadWithContext(currentContext.contextId(), currentContext.threadType());
110112

111113
// Complete phase 1
112114
phaser.arriveAndDeregister();

sdk/src/test/java/com/amazonaws/lambda/durable/StepConfigTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import com.amazonaws.lambda.durable.retry.RetryStrategies;
1010
import com.amazonaws.lambda.durable.serde.JacksonSerDes;
11-
import com.amazonaws.lambda.durable.serde.SerDes;
1211
import org.junit.jupiter.api.Test;
1312

1413
class StepConfigTest {
@@ -32,7 +31,7 @@ void testBuilderWithoutRetryStrategy() {
3231
@Test
3332
void testBuilderChaining() {
3433
var strategy = RetryStrategies.Presets.NO_RETRY;
35-
SerDes customSerDes = new JacksonSerDes();
34+
var customSerDes = new JacksonSerDes();
3635

3736
var config = StepConfig.builder()
3837
.retryStrategy(strategy)
@@ -61,7 +60,7 @@ void testSemanticsDefaultsToAtLeastOnce() {
6160

6261
@Test
6362
void testBuilderWithCustomSerDes() {
64-
SerDes customSerDes = new JacksonSerDes();
63+
var customSerDes = new JacksonSerDes();
6564

6665
var config = StepConfig.builder().serDes(customSerDes).build();
6766

@@ -86,7 +85,7 @@ void testBuilderWithNullSerDes() {
8685
@Test
8786
void testBuilderWithAllOptions() {
8887
var strategy = RetryStrategies.Presets.DEFAULT;
89-
SerDes customSerDes = new JacksonSerDes();
88+
var customSerDes = new JacksonSerDes();
9089

9190
var config = StepConfig.builder()
9291
.retryStrategy(strategy)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable.operation;
4+
5+
import static org.junit.jupiter.api.Assertions.*;
6+
import static org.mockito.Mockito.*;
7+
8+
import com.amazonaws.lambda.durable.execution.ExecutionManager;
9+
import com.amazonaws.lambda.durable.execution.OperationContext;
10+
import com.amazonaws.lambda.durable.execution.ThreadType;
11+
import com.amazonaws.lambda.durable.logging.DurableLogger;
12+
import com.amazonaws.lambda.durable.serde.JacksonSerDes;
13+
import java.util.concurrent.Phaser;
14+
import org.junit.jupiter.api.Test;
15+
16+
class StepOperationTest {
17+
18+
@Test
19+
void getThrowsIllegalStateExceptionWhenCalledFromStepContext() {
20+
var executionManager = mock(ExecutionManager.class);
21+
var phaser = new Phaser(1);
22+
when(executionManager.startPhaser(any())).thenReturn(phaser);
23+
when(executionManager.getCurrentContext()).thenReturn(new OperationContext("1-step", ThreadType.STEP));
24+
25+
var operation = new StepOperation<>(
26+
"1",
27+
"test-step",
28+
() -> "result",
29+
String.class,
30+
null,
31+
executionManager,
32+
mock(DurableLogger.class),
33+
new JacksonSerDes());
34+
35+
var ex = assertThrows(IllegalStateException.class, operation::get);
36+
assertTrue(ex.getMessage().contains("Nested step calling is not supported"));
37+
assertTrue(ex.getMessage().contains("test-step"));
38+
}
39+
40+
@Test
41+
void getDoesNotThrowWhenCalledFromHandlerContext() {
42+
var executionManager = mock(ExecutionManager.class);
43+
var phaser = new Phaser(1);
44+
phaser.arriveAndDeregister(); // Advance to phase 1 to skip blocking
45+
when(executionManager.startPhaser(any())).thenReturn(phaser);
46+
when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT));
47+
when(executionManager.getOperation("1"))
48+
.thenReturn(software.amazon.awssdk.services.lambda.model.Operation.builder()
49+
.id("1")
50+
.name("test-step")
51+
.status(software.amazon.awssdk.services.lambda.model.OperationStatus.SUCCEEDED)
52+
.stepDetails(software.amazon.awssdk.services.lambda.model.StepDetails.builder()
53+
.result("\"cached-result\"")
54+
.build())
55+
.build());
56+
57+
var operation = new StepOperation<>(
58+
"1",
59+
"test-step",
60+
() -> "result",
61+
String.class,
62+
null,
63+
executionManager,
64+
mock(DurableLogger.class),
65+
new JacksonSerDes());
66+
67+
var result = operation.get();
68+
assertEquals("cached-result", result);
69+
}
70+
}

0 commit comments

Comments
 (0)