Skip to content

Commit 216aae2

Browse files
committed
feat: Use per-context replay status
1 parent dbd1108 commit 216aae2

8 files changed

Lines changed: 109 additions & 24 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/context/BaseContext.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,12 @@ public interface BaseContext extends AutoCloseable {
4444
/** Gets the context name for this context. Null for root context. */
4545
String getContextName();
4646

47-
/** Returns whether this context is currently in replay mode. */
48-
boolean isReplaying();
47+
/**
48+
* Returns whether this context is currently replaying based on per-context tracking. Checks whether the next
49+
* operation in this specific context already exists in checkpoint storage, providing accurate replay status even
50+
* when multiple contexts run concurrently.
51+
*/
52+
boolean isReplayingContext();
4953

5054
/** Closes this context. */
5155
void close();

sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ public abstract class BaseContextImpl implements AutoCloseable, BaseContext {
1515
private final String contextName;
1616
private final ThreadType threadType;
1717

18-
private boolean isReplaying;
19-
2018
/**
2119
* Creates a new BaseContext instance.
2220
*
@@ -39,7 +37,6 @@ protected BaseContextImpl(
3937
this.lambdaContext = lambdaContext;
4038
this.contextId = contextId;
4139
this.contextName = contextName;
42-
this.isReplaying = executionManager.hasOperationsForContext(contextId);
4340
this.threadType = threadType;
4441
}
4542

@@ -97,16 +94,12 @@ public ExecutionManager getExecutionManager() {
9794
return executionManager;
9895
}
9996

100-
/** Returns whether this context is currently in replay mode. */
101-
@Override
102-
public boolean isReplaying() {
103-
return isReplaying;
104-
}
105-
10697
/**
107-
* Transitions this context from replay to execution mode. Called when the first un-cached operation is encountered.
98+
* Returns whether this context is currently in replay mode. The default implementation returns false. Subclasses
99+
* that track per-context replay status (like DurableContextImpl) override this.
108100
*/
109-
public void setExecutionMode() {
110-
this.isReplaying = false;
101+
@Override
102+
public boolean isReplayingContext() {
103+
return false;
111104
}
112105
}

sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public class DurableContextImpl extends BaseContextImpl implements DurableContex
6464
private final DurableContextImpl parentContext;
6565
private final boolean isVirtual;
6666
private volatile DurableLogger logger;
67+
private boolean replayMode;
6768

6869
/** Shared initialization — sets all fields. */
6970
private DurableContextImpl(
@@ -78,6 +79,8 @@ private DurableContextImpl(
7879
operationIdGenerator = new OperationIdGenerator(contextId);
7980
this.parentContext = parentContext;
8081
this.isVirtual = isVirtual;
82+
// Initialize replay mode by checking if the next operation (first in this context) exists in storage
83+
this.replayMode = executionManager.hasOperation(operationIdGenerator.peekNextOperationId());
8184
}
8285

8386
/**
@@ -143,6 +146,7 @@ public <T> DurableFuture<T> stepAsync(
143146
config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build();
144147
}
145148
var operationId = nextOperationId();
149+
updateReplayStatus();
146150

147151
// Create and start step operation with TypeToken
148152
var operation = new StepOperation<>(
@@ -159,6 +163,7 @@ public DurableFuture<Void> waitAsync(String name, Duration duration) {
159163
ParameterValidator.validateOperationName(name);
160164

161165
var operationId = nextOperationId();
166+
updateReplayStatus();
162167

163168
// Create and start wait operation
164169
var operation =
@@ -184,6 +189,7 @@ public <T, U> DurableFuture<T> invokeAsync(
184189
.build();
185190
}
186191
var operationId = nextOperationId();
192+
updateReplayStatus();
187193

188194
// Create and start invoke operation
189195
var operation = new InvokeOperation<>(
@@ -205,6 +211,7 @@ public <T> DurableCallbackFuture<T> createCallback(String name, TypeToken<T> res
205211
config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build();
206212
}
207213
var operationId = nextOperationId();
214+
updateReplayStatus();
208215

209216
var operation = new CallbackOperation<>(
210217
OperationIdentifier.of(operationId, name, OperationType.CALLBACK), resultType, config, this);
@@ -246,6 +253,7 @@ private <T> DurableFuture<T> runInChildContextAsync(
246253
}
247254

248255
var operationId = nextOperationId();
256+
updateReplayStatus();
249257

250258
var operation = new ChildContextOperation<>(
251259
OperationIdentifier.of(operationId, name, OperationType.CONTEXT, subType),
@@ -275,6 +283,7 @@ public <I, O> DurableFuture<MapResult<O>> mapAsync(
275283
// Convert to List for deterministic index-based access
276284
var itemList = List.copyOf(items);
277285
var operationId = nextOperationId();
286+
updateReplayStatus();
278287

279288
var operation = new MapOperation<>(
280289
OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.MAP),
@@ -291,6 +300,7 @@ public <I, O> DurableFuture<MapResult<O>> mapAsync(
291300
public ParallelDurableFuture parallel(String name, ParallelConfig config) {
292301
Objects.requireNonNull(config, "config cannot be null");
293302
var operationId = nextOperationId();
303+
updateReplayStatus();
294304

295305
var parallelOp = new ParallelOperation(
296306
OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL),
@@ -362,6 +372,7 @@ public <T> DurableFuture<T> waitForConditionAsync(
362372
config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build();
363373
}
364374
var operationId = nextOperationId();
375+
updateReplayStatus();
365376

366377
var operation = new WaitForConditionOperation<>(operationId, name, checkFunc, resultType, config, this);
367378

@@ -454,6 +465,28 @@ public void close() {
454465
}
455466
}
456467

468+
/**
469+
* Returns whether this context is currently in replay mode based on per-context tracking. A context is replaying
470+
* when its next operation already exists in checkpoint storage.
471+
*/
472+
@Override
473+
public boolean isReplayingContext() {
474+
return replayMode;
475+
}
476+
477+
/**
478+
* Checks if the next operation exists in checkpoint storage and transitions out of replay mode if it does not. This
479+
* is called before each operation to maintain accurate per-context replay status.
480+
*/
481+
public void updateReplayStatus() {
482+
if (!replayMode) {
483+
return;
484+
}
485+
if (!getExecutionManager().hasOperation(operationIdGenerator.peekNextOperationId())) {
486+
replayMode = false;
487+
}
488+
}
489+
457490
/**
458491
* Get the next operationId. Returns a globally unique operation ID by hashing a sequential operation counter. For
459492
* root contexts, the counter value is hashed directly (e.g. "1", "2", "3"). For child contexts, the values are

sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ public boolean hasOperationsForContext(String parentId) {
178178
return operationStorage.values().stream().anyMatch(op -> Objects.equals(op.parentId(), parentId));
179179
}
180180

181+
/**
182+
* Checks whether an operation with the given ID exists in checkpoint storage.
183+
*
184+
* @param operationId the operation ID to check
185+
* @return true if the operation exists
186+
*/
187+
public boolean hasOperation(String operationId) {
188+
return operationStorage.containsKey(operationId);
189+
}
190+
181191
// ===== Thread Coordination =====
182192
/** Sets the current thread's ThreadContext (threadId and threadType). Called when a user thread is started. */
183193
public void setCurrentThreadContext(ThreadContext threadContext) {

sdk/src/main/java/software/amazon/lambda/durable/execution/OperationIdGenerator.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,13 @@ public String nextOperationId() {
4545
var counter = String.valueOf(operationCounter.incrementAndGet());
4646
return hashOperationId(operationIdPrefix + counter);
4747
}
48+
49+
/**
50+
* Returns the operation ID that would be generated by the next call to {@link #nextOperationId()} without
51+
* incrementing the counter. Used to check whether the next operation already exists in checkpoint storage.
52+
*/
53+
public String peekNextOperationId() {
54+
var counter = String.valueOf(operationCounter.get() + 1);
55+
return hashOperationId(operationIdPrefix + counter);
56+
}
4857
}

sdk/src/main/java/software/amazon/lambda/durable/logging/DurableLogger.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ public void error(String message, Throwable t) {
9393
}
9494

9595
private boolean shouldSuppress() {
96-
return context.getDurableConfig().getLoggerConfig().suppressReplayLogs()
97-
&& context.getExecutionManager().isReplaying();
96+
return context.getDurableConfig().getLoggerConfig().suppressReplayLogs() && context.isReplayingContext();
9897
}
9998

10099
private void log(Runnable logAction) {

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,6 @@ public void execute() {
134134
}
135135
replay(existing);
136136
} else {
137-
if (durableContext.isReplaying()) {
138-
this.durableContext.setExecutionMode();
139-
}
140137
start();
141138
}
142139
}

sdk/src/test/java/software/amazon/lambda/durable/logging/DurableLoggerTest.java

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import software.amazon.lambda.durable.TestContext;
1414
import software.amazon.lambda.durable.context.DurableContextImpl;
1515
import software.amazon.lambda.durable.execution.ExecutionManager;
16+
import software.amazon.lambda.durable.execution.OperationIdGenerator;
1617

1718
class DurableLoggerTest {
1819
private static final String EXECUTION_NAME = "exec-123";
@@ -42,7 +43,7 @@ void setUp() {
4243
}
4344

4445
private DurableLogger createLogger(Mode mode, Suppression suppression) {
45-
when(mockExecutionManager.isReplaying()).thenReturn(mode == Mode.REPLAYING);
46+
when(mockExecutionManager.hasOperation(anyString())).thenReturn(mode == Mode.REPLAYING);
4647
return new DurableLogger(mockLogger, createDurableContext(REQUEST_ID, suppression));
4748
}
4849

@@ -104,7 +105,7 @@ void setsExecutionMdcInConstructor() {
104105
void setStepThreadPropertiesSetsMdc() {
105106
try (MockedStatic<MDC> mdcMock = mockStatic(MDC.class)) {
106107
mdcMock.clearInvocations();
107-
when(mockExecutionManager.isReplaying()).thenReturn(false);
108+
when(mockExecutionManager.hasOperation(anyString())).thenReturn(false);
108109
var logger = new DurableLogger(
109110
mockLogger,
110111
createDurableContext(REQUEST_ID, Suppression.ENABLED)
@@ -130,13 +131,18 @@ void clearThreadPropertiesRemovesMdc() {
130131

131132
@Test
132133
void replayModeTransitionAllowsSubsequentLogs() {
133-
when(mockExecutionManager.isReplaying()).thenReturn(true, false);
134-
var logger = new DurableLogger(mockLogger, createDurableContext(REQUEST_ID, Suppression.ENABLED));
134+
when(mockExecutionManager.hasOperation(anyString())).thenReturn(true);
135+
var durableContext = createDurableContext(REQUEST_ID, Suppression.ENABLED);
136+
var logger = new DurableLogger(mockLogger, durableContext);
135137

136138
// During replay - suppressed
137139
logger.info("suppressed");
138140
verify(mockLogger, never()).info(anyString(), any(Object[].class));
139141

142+
// Simulate next operation not existing in storage — triggers transition out of replay
143+
when(mockExecutionManager.hasOperation(anyString())).thenReturn(false);
144+
durableContext.updateReplayStatus();
145+
140146
// After transition to execution mode - logged
141147
logger.info("logged after transition");
142148
verify(mockLogger).info(eq("logged after transition"), any(Object[].class));
@@ -163,10 +169,44 @@ void allLogLevelsDelegateCorrectly() {
163169
verify(mockLogger).error("error with exception", exception);
164170
}
165171

172+
@Test
173+
void concurrentContextsHaveIndependentReplayState() {
174+
var rootNextOp = OperationIdGenerator.hashOperationId("1");
175+
var childANextOp = OperationIdGenerator.hashOperationId("child-a-1");
176+
var childBNextOp = OperationIdGenerator.hashOperationId("child-b-1");
177+
178+
when(mockExecutionManager.hasOperation(rootNextOp)).thenReturn(true);
179+
when(mockExecutionManager.hasOperation(childANextOp)).thenReturn(false);
180+
when(mockExecutionManager.hasOperation(childBNextOp)).thenReturn(true);
181+
182+
var rootContext = createDurableContext(REQUEST_ID, Suppression.ENABLED);
183+
var childA = rootContext.createChildContext("child-a", "branch-a", false);
184+
var childB = rootContext.createChildContext("child-b", "branch-b", false);
185+
186+
var loggerForA = mock(Logger.class);
187+
var loggerForB = mock(Logger.class);
188+
var durableLoggerA = new DurableLogger(loggerForA, childA);
189+
var durableLoggerB = new DurableLogger(loggerForB, childB);
190+
191+
// Child A is in execution mode — logs should pass through
192+
durableLoggerA.info("from branch A");
193+
verify(loggerForA).info(eq("from branch A"), any(Object[].class));
194+
195+
// Child B is still replaying — logs should be suppressed
196+
durableLoggerB.info("from branch B");
197+
verify(loggerForB, never()).info(anyString(), any(Object[].class));
198+
199+
// After child B transitions, its logs should pass through
200+
when(mockExecutionManager.hasOperation(childBNextOp)).thenReturn(false);
201+
childB.updateReplayStatus();
202+
durableLoggerB.info("branch B after transition");
203+
verify(loggerForB).info(eq("branch B after transition"), any(Object[].class));
204+
}
205+
166206
@Test
167207
void handlesNullRequestId() {
168208
try (MockedStatic<MDC> mdcMock = mockStatic(MDC.class)) {
169-
when(mockExecutionManager.isReplaying()).thenReturn(false);
209+
when(mockExecutionManager.hasOperation(anyString())).thenReturn(false);
170210
new DurableLogger(mockLogger, createDurableContext(null, Suppression.DISABLED));
171211

172212
mdcMock.verify(() -> MDC.put(DurableLogger.MDC_EXECUTION_ARN, EXECUTION_ARN));

0 commit comments

Comments
 (0)