Skip to content

Commit 7adca38

Browse files
committed
Synchronize step ids with thread names and de-register / register caller threads in .get().
1 parent 478dc63 commit 7adca38

10 files changed

Lines changed: 165 additions & 63 deletions

File tree

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable.examples;
4+
5+
import com.amazonaws.lambda.durable.DurableContext;
6+
import com.amazonaws.lambda.durable.DurableHandler;
7+
import com.amazonaws.lambda.durable.StepConfig;
8+
import com.amazonaws.lambda.durable.retry.RetryStrategies;
9+
10+
/**
11+
* Example demonstrating nested step calling with stepAsync.
12+
*
13+
* <p>This example shows how to:
14+
*
15+
* <ol>
16+
* <li>Create an async step that performs long-running work
17+
* <li>Use the result of that async step within another step
18+
* <li>Properly coordinate execution between multiple steps
19+
* </ol>
20+
*/
21+
public class NestedStepExample extends DurableHandler<Object, String> {
22+
23+
@Override
24+
public String handleRequest(Object input, DurableContext context) {
25+
// Step 1: Create an async step that performs long-running work
26+
var durableFuture1 = context.stepAsync(
27+
"async-step",
28+
String.class,
29+
() -> {
30+
try {
31+
Thread.sleep(10000); // Simulate 10 seconds of work
32+
} catch (InterruptedException e) {
33+
Thread.currentThread().interrupt();
34+
throw new RuntimeException("Interrupted", e);
35+
}
36+
return "async-result";
37+
},
38+
StepConfig.builder()
39+
.retryStrategy(RetryStrategies.Presets.DEFAULT)
40+
.build());
41+
42+
// Step 2: Process the result of the async step
43+
var result = context.step("process-result", String.class, () -> {
44+
var asyncResult = durableFuture1.get(); // Wait for async step to complete
45+
return asyncResult + "-processed";
46+
});
47+
48+
return result;
49+
}
50+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable.examples;
4+
5+
import static org.junit.jupiter.api.Assertions.*;
6+
7+
import com.amazonaws.lambda.durable.model.ExecutionStatus;
8+
import com.amazonaws.lambda.durable.testing.LocalDurableTestRunner;
9+
import org.junit.jupiter.api.Test;
10+
11+
class NestedStepExampleTest {
12+
13+
@Test
14+
void testNestedStepExample() {
15+
// Test nested step calling with stepAsync
16+
var handler = new NestedStepExample();
17+
var runner = LocalDurableTestRunner.create(Object.class, handler);
18+
19+
var result = runner.run("test-input");
20+
21+
assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus());
22+
assertEquals("async-result-processed", result.getResult(String.class));
23+
}
24+
25+
@Test
26+
void testReplay() {
27+
// Test replay behavior with nested steps
28+
var handler = new NestedStepExample();
29+
var runner = LocalDurableTestRunner.create(Object.class, handler);
30+
31+
var input = "replay-test";
32+
33+
// First execution
34+
var result1 = runner.run(input);
35+
36+
// Second execution (replay)
37+
var result2 = runner.run(input);
38+
39+
assertEquals(result1.getStatus(), result2.getStatus());
40+
assertEquals(result1.getResult(String.class), result2.getResult(String.class));
41+
}
42+
}

examples/template.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,28 @@ Resources:
186186
DockerContext: ../
187187
DockerTag: durable-examples
188188

189+
NestedStepExampleFunction:
190+
Type: AWS::Serverless::Function
191+
Properties:
192+
PackageType: Image
193+
FunctionName: nested-step-example
194+
ImageConfig:
195+
Command: ["com.amazonaws.lambda.durable.examples.NestedStepExample::handleRequest"]
196+
DurableConfig:
197+
ExecutionTimeout: 300
198+
RetentionPeriodInDays: 7
199+
Policies:
200+
- Statement:
201+
- Effect: Allow
202+
Action:
203+
- lambda:CheckpointDurableExecutions
204+
- lambda:GetDurableExecutionState
205+
Resource: !Sub "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:nested-step-example"
206+
Metadata:
207+
Dockerfile: examples/Dockerfile
208+
DockerContext: ../
209+
DockerTag: durable-examples
210+
189211
Outputs:
190212
SimpleStepExampleFunction:
191213
Description: Simple Step Example Function ARN
@@ -250,3 +272,11 @@ Outputs:
250272
CustomConfigExampleFunctionName:
251273
Description: Custom Config Example Function Name
252274
Value: !Ref CustomConfigExampleFunction
275+
276+
NestedStepExampleFunction:
277+
Description: Nested Step Example Function ARN
278+
Value: !GetAtt NestedStepExampleFunction.Arn
279+
280+
NestedStepExampleFunctionName:
281+
Description: Nested Step Example Function Name
282+
Value: !Ref NestedStepExampleFunction

sdk/src/main/java/com/amazonaws/lambda/durable/DurableConfig.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,11 @@ public static final class Builder {
183183
private Builder() {}
184184

185185
/**
186-
* Sets a custom LambdaClient for production use. Use this method to customize the AWS SDK client
187-
* with specific regions, credentials, timeouts, or retry policies.
186+
* Sets a custom LambdaClient for production use. Use this method to customize the AWS SDK client with specific
187+
* regions, credentials, timeouts, or retry policies.
188188
*
189189
* <p>Example:
190+
*
190191
* <pre>{@code
191192
* LambdaClient lambdaClient = LambdaClient.builder()
192193
* .region(Region.US_WEST_2)
@@ -211,9 +212,9 @@ public Builder withLambdaClient(LambdaClient lambdaClient) {
211212
/**
212213
* Sets a custom DurableExecutionClient.
213214
*
214-
* <p><b>Note:</b> This method is primarily intended for testing with mock clients
215-
* (e.g., {@code LocalMemoryExecutionClient}). For production use with a custom AWS SDK client,
216-
* prefer {@link #withLambdaClient(LambdaClient)}.
215+
* <p><b>Note:</b> This method is primarily intended for testing with mock clients (e.g.,
216+
* {@code LocalMemoryExecutionClient}). For production use with a custom AWS SDK client, prefer
217+
* {@link #withLambdaClient(LambdaClient)}.
217218
*
218219
* @param durableExecutionClient Custom DurableExecutionClient instance
219220
* @return This builder

sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import com.amazonaws.lambda.durable.exception.NonDeterministicExecutionException;
66
import com.amazonaws.lambda.durable.execution.ExecutionManager;
7-
import com.amazonaws.lambda.durable.execution.ThreadType;
87
import com.amazonaws.lambda.durable.operation.StepOperation;
98
import com.amazonaws.lambda.durable.operation.WaitOperation;
109
import com.amazonaws.lambda.durable.retry.RetryStrategies;
@@ -22,17 +21,17 @@ public class DurableContext {
2221
private final SerDes serDes;
2322
private final Context lambdaContext;
2423
private final AtomicInteger operationCounter;
24+
private final String uniqueThreadName;
2525

2626
DurableContext(ExecutionManager executionManager, SerDes serDes, Context lambdaContext) {
2727
this.executionManager = executionManager;
2828
this.serDes = serDes;
2929
this.lambdaContext = lambdaContext;
3030
this.operationCounter = new AtomicInteger(0);
31+
this.uniqueThreadName = Thread.currentThread().getName(); // Auto-detect handler thread
3132

32-
// Register root context thread as active
33-
// TODO: Once we implement child contexts, the threadId needs to be the ID of
34-
// the child context
35-
executionManager.registerActiveThread("Root", ThreadType.CONTEXT);
33+
// Register context thread as active
34+
executionManager.registerActiveThread(uniqueThreadName);
3635
}
3736

3837
public <T> T step(String name, Class<T> resultType, Supplier<T> func) {

sdk/src/main/java/com/amazonaws/lambda/durable/DurableExecutor.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@ public static <I, O> DurableExecutionOutput execute(
7171
var serDes = config.getSerDes();
7272
var userInput = extractUserInput(executionOp, serDes, inputType);
7373

74-
// Create context
75-
var context = new DurableContext(executionManager, serDes, lambdaContext);
76-
7774
try {
78-
var handlerFuture = CompletableFuture.supplyAsync(() -> handler.apply(userInput, context), executor);
75+
var handlerFuture = CompletableFuture.supplyAsync(
76+
() -> {
77+
// Create context in the executor thread so it detects the correct thread name
78+
var context = new DurableContext(executionManager, serDes, lambdaContext);
79+
return handler.apply(userInput, context);
80+
},
81+
executor);
7982

8083
// Get suspend future from ExecutionManager. If this future completes, it
8184
// indicates

sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import java.time.Instant;
99
import java.util.Collections;
1010
import java.util.HashMap;
11+
import java.util.HashSet;
1112
import java.util.List;
1213
import java.util.Map;
14+
import java.util.Set;
1315
import java.util.concurrent.CompletableFuture;
1416
import java.util.concurrent.ConcurrentHashMap;
1517
import java.util.concurrent.Executor;
@@ -47,7 +49,7 @@ public class ExecutionManager {
4749
private final String durableExecutionArn;
4850

4951
// ===== Thread Coordination =====
50-
private final Map<String, ThreadType> activeThreads = Collections.synchronizedMap(new HashMap<>());
52+
private final Set<String> activeThreads = Collections.synchronizedSet(new HashSet<>());
5153
private final Map<String, Phaser> openPhasers = Collections.synchronizedMap(new HashMap<>());
5254
private final CompletableFuture<Void> suspendExecutionFuture = new CompletableFuture<>();
5355

@@ -137,19 +139,15 @@ public Operation getExecutionOperation() {
137139

138140
// ===== Thread Coordination =====
139141

140-
public void registerActiveThread(String threadId, ThreadType threadType) {
141-
if (activeThreads.containsKey(threadId)) {
142-
logger.debug("Thread '{}' ({}) already registered as active", threadId, threadType);
142+
public void registerActiveThread(String threadId) {
143+
if (activeThreads.contains(threadId)) {
144+
logger.debug("Thread '{}' already registered as active", threadId);
143145
return;
144146
}
145147

146148
synchronized (this) {
147-
activeThreads.put(threadId, threadType);
148-
logger.debug(
149-
"Registered thread '{}' ({}) as active. Active threads: {}",
150-
threadId,
151-
threadType,
152-
activeThreads.size());
149+
activeThreads.add(threadId);
150+
logger.debug("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
153151
}
154152
}
155153

@@ -159,14 +157,17 @@ public void deregisterActiveThread(String threadId) {
159157
return;
160158
}
161159

162-
if (!activeThreads.containsKey(threadId)) {
163-
logger.warn("Thread '{}' not active, cannot deregister", threadId);
160+
if (!activeThreads.contains(threadId)) {
161+
logger.warn(
162+
"Thread '{}' not active, cannot deregister. Current thread: {}",
163+
threadId,
164+
Thread.currentThread().getName());
164165
return;
165166
}
166167

167168
synchronized (this) {
168-
ThreadType type = activeThreads.remove(threadId);
169-
logger.debug("Deregistered thread '{}' ({}). Active threads: {}", threadId, type, activeThreads.size());
169+
activeThreads.remove(threadId);
170+
logger.debug("Deregistered thread '{}'. Active threads: {}", threadId, activeThreads.size());
170171

171172
if (activeThreads.isEmpty()) {
172173
logger.info("No active threads remaining - suspending execution");

sdk/src/main/java/com/amazonaws/lambda/durable/execution/ThreadType.java

Lines changed: 0 additions & 24 deletions
This file was deleted.

sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import com.amazonaws.lambda.durable.exception.StepInterruptedException;
1010
import com.amazonaws.lambda.durable.execution.ExecutionManager;
1111
import com.amazonaws.lambda.durable.execution.ExecutionPhase;
12-
import com.amazonaws.lambda.durable.execution.ThreadType;
1312
import com.amazonaws.lambda.durable.serde.SerDes;
1413
import java.time.Duration;
1514
import java.time.Instant;
@@ -159,10 +158,12 @@ public void execute() {
159158
private void executeStepLogic(int attempt) {
160159
// Register step thread as active
161160
var stepThreadId = operationId + "-step";
162-
executionManager.registerActiveThread(stepThreadId, ThreadType.STEP);
161+
executionManager.registerActiveThread(stepThreadId);
163162

164163
// Execute in managed executor
165164
executionManager.getManagedExecutor().execute(() -> {
165+
// Set thread name to match the step thread ID for consistent logging and identification
166+
Thread.currentThread().setName(stepThreadId);
166167
try {
167168
// Check if we need to send START
168169
var existing = executionManager.getOperation(operationId);
@@ -290,16 +291,16 @@ public T get() {
290291
phaser.register();
291292

292293
// Deregister current thread - allows suspension
293-
// TODO: The threadId here should be the (potential childContext) thread id that
294-
// is calling .get()
295-
executionManager.deregisterActiveThread("Root");
294+
String currentThreadName = Thread.currentThread().getName();
295+
logger.debug("StepOperation.get() attempting to deregister thread: {}", currentThreadName);
296+
executionManager.deregisterActiveThread(currentThreadName);
296297

297298
// Block until operation completes
298299
logger.debug("Waiting for operation to finish {} (Phaser: {})", operationId, phaser);
299300
phaser.arriveAndAwaitAdvance(); // Wait for phase 0
300301

301302
// Reactivate current thread
302-
executionManager.registerActiveThread("Root", ThreadType.CONTEXT);
303+
executionManager.registerActiveThread(currentThreadName);
303304

304305
// Complete phase 1
305306
phaser.arriveAndDeregister();
@@ -346,8 +347,7 @@ public T get() {
346347
throw new StepFailedException(
347348
String.format(
348349
"Step failed with error of type %s. Message: %s",
349-
errorType,
350-
op.stepDetails().error().errorMessage()),
350+
errorType, op.stepDetails().error().errorMessage()),
351351
null,
352352
// Preserve original stack trace
353353
StepFailedException.deserializeStackTrace(

sdk/src/main/java/com/amazonaws/lambda/durable/operation/WaitOperation.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import com.amazonaws.lambda.durable.execution.ExecutionManager;
66
import com.amazonaws.lambda.durable.execution.ExecutionPhase;
7-
import com.amazonaws.lambda.durable.execution.ThreadType;
87
import java.time.Duration;
98
import java.time.Instant;
109
import java.util.concurrent.Phaser;
@@ -100,14 +99,15 @@ public Void get() {
10099

101100
// Deregister current thread - THIS is where suspension can happen!
102101
// If no other threads are active, this will throw SuspendExecutionException
103-
executionManager.deregisterActiveThread("Root");
102+
String callingThreadName = Thread.currentThread().getName();
103+
executionManager.deregisterActiveThread(callingThreadName);
104104

105105
// Complete the wait phaser immediately (we don't actually wait in Lambda)
106106
// The backend handles the wait duration
107107
phaser.arriveAndAwaitAdvance(); // Phase 0 -> 1
108108

109109
// Reactivate current thread
110-
executionManager.registerActiveThread("Root", ThreadType.CONTEXT);
110+
executionManager.registerActiveThread(callingThreadName);
111111

112112
// Complete phase 1
113113
phaser.arriveAndDeregister();

0 commit comments

Comments
 (0)