Skip to content

Commit d4a8306

Browse files
committed
Raise exception when called from nested steps and delete example.
1 parent 3ef28a0 commit d4a8306

10 files changed

Lines changed: 220 additions & 157 deletions

File tree

examples/src/main/java/com/amazonaws/lambda/durable/examples/NestedStepExample.java

Lines changed: 0 additions & 50 deletions
This file was deleted.

examples/src/test/java/com/amazonaws/lambda/durable/examples/NestedStepExampleTest.java

Lines changed: 0 additions & 48 deletions
This file was deleted.

examples/template.yaml

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -186,27 +186,6 @@ Resources:
186186
DockerContext: ../
187187
DockerTag: durable-examples
188188

189-
NestedStepExampleFunction:
190-
Type: AWS::Serverless::Function
191-
Properties:
192-
PackageType: Image
193-
FunctionName: nested-step-example
194-
ImageConfig:
195-
Command: ["com.amazonaws.lambda.durable.examples.NestedStepExample::handleRequest"]
196-
DurableConfig:
197-
ExecutionTimeout: 300
198-
RetentionPeriodInDays: 7
199-
Policies:
200-
- Statement:
201-
- Effect: Allow
202-
Action:
203-
- lambda:CheckpointDurableExecutions
204-
- lambda:GetDurableExecutionState
205-
Resource: !Sub "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:nested-step-example"
206-
Metadata:
207-
Dockerfile: examples/Dockerfile
208-
DockerContext: ../
209-
DockerTag: durable-examples
210189

211190
Outputs:
212191
SimpleStepExampleFunction:
@@ -273,10 +252,3 @@ Outputs:
273252
Description: Custom Config Example Function Name
274253
Value: !Ref CustomConfigExampleFunction
275254

276-
NestedStepExampleFunction:
277-
Description: Nested Step Example Function ARN
278-
Value: !GetAtt NestedStepExampleFunction.Arn
279-
280-
NestedStepExampleFunctionName:
281-
Description: Nested Step Example Function Name
282-
Value: !Ref NestedStepExampleFunction
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable;
4+
5+
import static org.junit.jupiter.api.Assertions.*;
6+
7+
import com.amazonaws.lambda.durable.model.ExecutionStatus;
8+
import com.amazonaws.lambda.durable.testing.LocalDurableTestRunner;
9+
import org.junit.jupiter.api.Test;
10+
11+
/** Tests that nested step calling is properly rejected. */
12+
class NestedStepIntegrationTest {
13+
14+
@Test
15+
void nestedStepCallingThrowsIllegalStateException() {
16+
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
17+
// outer-step's supplier calls context.step() which internally calls stepAsync().get()
18+
// The get() is called from the outer step's thread (named "1-step"), triggering the check
19+
var future = context.stepAsync("outer-step", String.class, () -> {
20+
return context.step("inner-step", String.class, () -> "inner-result");
21+
});
22+
return future.get();
23+
});
24+
25+
var result = runner.run("test");
26+
27+
assertEquals(ExecutionStatus.FAILED, result.getStatus());
28+
var errorMessage = result.getError().get().errorMessage();
29+
assertTrue(
30+
errorMessage.contains("Nested step calling is not supported"),
31+
"Expected error about nested step calling, got: " + errorMessage);
32+
}
33+
34+
@Test
35+
void awaitingAsyncStepInsideSyncStepThrowsIllegalStateException() {
36+
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
37+
// Start async step from handler thread
38+
var asyncFuture = context.stepAsync("async-step", String.class, () -> "async-result");
39+
40+
// Sync step tries to await the async step's result inside its supplier
41+
return context.step("sync-step", String.class, () -> {
42+
// This get() is called from sync-step's thread ("2-step"), which is not allowed
43+
return "combined: " + asyncFuture.get();
44+
});
45+
});
46+
47+
var result = runner.run("test");
48+
49+
assertEquals(ExecutionStatus.FAILED, result.getStatus());
50+
var errorMessage = result.getError().get().errorMessage();
51+
assertTrue(
52+
errorMessage.contains("Nested step calling is not supported"),
53+
"Expected error about nested step calling, got: " + errorMessage);
54+
}
55+
}

sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import com.amazonaws.lambda.durable.exception.NonDeterministicExecutionException;
66
import com.amazonaws.lambda.durable.execution.ExecutionManager;
7+
import com.amazonaws.lambda.durable.execution.ThreadType;
78
import com.amazonaws.lambda.durable.operation.StepOperation;
89
import com.amazonaws.lambda.durable.operation.WaitOperation;
910
import com.amazonaws.lambda.durable.retry.RetryStrategies;
@@ -17,21 +18,22 @@
1718
import software.amazon.awssdk.services.lambda.model.OperationType;
1819

1920
public class DurableContext {
21+
private static final String HANDLER_CONTEXT_ID = "handler";
22+
2023
private final ExecutionManager executionManager;
2124
private final SerDes serDes;
2225
private final Context lambdaContext;
2326
private final AtomicInteger operationCounter;
24-
private final String uniqueThreadName;
2527

2628
DurableContext(ExecutionManager executionManager, SerDes serDes, Context lambdaContext) {
2729
this.executionManager = executionManager;
2830
this.serDes = serDes;
2931
this.lambdaContext = lambdaContext;
3032
this.operationCounter = new AtomicInteger(0);
31-
this.uniqueThreadName = Thread.currentThread().getName(); // Auto-detect handler thread
3233

33-
// Register context thread as active
34-
executionManager.registerActiveThread(uniqueThreadName);
34+
// Register handler context as active
35+
executionManager.registerActiveThread(HANDLER_CONTEXT_ID, ThreadType.CONTEXT);
36+
executionManager.enterContext(HANDLER_CONTEXT_ID, ThreadType.CONTEXT);
3537
}
3638

3739
public <T> T step(String name, Class<T> resultType, Supplier<T> func) {

sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
import java.time.Instant;
99
import java.util.Collections;
1010
import java.util.HashMap;
11-
import java.util.HashSet;
1211
import java.util.List;
1312
import java.util.Map;
14-
import java.util.Set;
1513
import java.util.concurrent.CompletableFuture;
1614
import java.util.concurrent.ConcurrentHashMap;
1715
import java.util.concurrent.Executor;
@@ -49,7 +47,9 @@ public class ExecutionManager {
4947
private final String durableExecutionArn;
5048

5149
// ===== Thread Coordination =====
52-
private final Set<String> activeThreads = Collections.synchronizedSet(new HashSet<>());
50+
private final Map<String, ThreadType> activeThreads = Collections.synchronizedMap(new HashMap<>());
51+
private static final ThreadLocal<String> currentContextId = new ThreadLocal<>();
52+
private static final ThreadLocal<ThreadType> currentThreadType = new ThreadLocal<>();
5353
private final Map<String, Phaser> openPhasers = Collections.synchronizedMap(new HashMap<>());
5454
private final CompletableFuture<Void> suspendExecutionFuture = new CompletableFuture<>();
5555

@@ -139,35 +139,70 @@ public Operation getExecutionOperation() {
139139

140140
// ===== Thread Coordination =====
141141

142-
public void registerActiveThread(String threadId) {
143-
if (activeThreads.contains(threadId)) {
144-
logger.debug("Thread '{}' already registered as active", threadId);
142+
public void registerActiveThread(String threadId, ThreadType threadType) {
143+
if (activeThreads.containsKey(threadId)) {
144+
logger.debug("Thread '{}' ({}) already registered as active", threadId, threadType);
145145
return;
146146
}
147147

148148
synchronized (this) {
149-
activeThreads.add(threadId);
150-
logger.debug("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size());
149+
activeThreads.put(threadId, threadType);
150+
logger.debug(
151+
"Registered thread '{}' ({}) as active. Active threads: {}",
152+
threadId,
153+
threadType,
154+
activeThreads.size());
151155
}
152156
}
153157

158+
/**
159+
* Sets the current thread's context. Call this when entering a context (handler or step).
160+
* This is separate from registerActiveThread to support cases where registration happens
161+
* on a different thread than execution.
162+
*/
163+
public void enterContext(String contextId, ThreadType threadType) {
164+
currentContextId.set(contextId);
165+
currentThreadType.set(threadType);
166+
}
167+
168+
/**
169+
* Clears the current thread's context. Call this when exiting a context.
170+
*/
171+
public void exitContext() {
172+
currentContextId.remove();
173+
currentThreadType.remove();
174+
}
175+
176+
/**
177+
* Returns the ThreadType for the current thread, or null if not registered.
178+
* This uses ThreadLocal to track context independent of thread naming.
179+
*/
180+
public ThreadType getCurrentThreadType() {
181+
return currentThreadType.get();
182+
}
183+
184+
/**
185+
* Returns the context ID for the current thread, or null if not registered.
186+
* This uses ThreadLocal to track context independent of thread naming.
187+
*/
188+
public String getCurrentContextId() {
189+
return currentContextId.get();
190+
}
191+
154192
public void deregisterActiveThread(String threadId) {
155193
// Skip if already suspended
156194
if (suspendExecutionFuture.isDone()) {
157195
return;
158196
}
159197

160-
if (!activeThreads.contains(threadId)) {
161-
logger.warn(
162-
"Thread '{}' not active, cannot deregister. Current thread: {}",
163-
threadId,
164-
Thread.currentThread().getName());
198+
if (!activeThreads.containsKey(threadId)) {
199+
logger.warn("Thread '{}' not active, cannot deregister", threadId);
165200
return;
166201
}
167202

168203
synchronized (this) {
169-
activeThreads.remove(threadId);
170-
logger.debug("Deregistered thread '{}'. Active threads: {}", threadId, activeThreads.size());
204+
ThreadType type = activeThreads.remove(threadId);
205+
logger.debug("Deregistered thread '{}' ({}). Active threads: {}", threadId, type, activeThreads.size());
171206

172207
if (activeThreads.isEmpty()) {
173208
logger.info("No active threads remaining - suspending execution");
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package com.amazonaws.lambda.durable.execution;
4+
5+
/**
6+
* Thread type enum for tracking conceptual threads in durable execution.
7+
*
8+
* <p>These are not physical OS threads, but logical threads representing different types of work in the execution.
9+
*/
10+
public enum ThreadType {
11+
CONTEXT("Context"),
12+
STEP("Step");
13+
14+
private final String displayName;
15+
16+
ThreadType(String displayName) {
17+
this.displayName = displayName;
18+
}
19+
20+
@Override
21+
public String toString() {
22+
return displayName;
23+
}
24+
}

0 commit comments

Comments
 (0)