Skip to content

Commit a216fca

Browse files
committed
fix thread synchronization
1 parent 1a608bc commit a216fca

4 files changed

Lines changed: 50 additions & 34 deletions

File tree

sdk-integration-tests/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,12 @@
4141
<artifactId>junit-jupiter</artifactId>
4242
<scope>test</scope>
4343
</dependency>
44+
<dependency>
45+
<groupId>org.slf4j</groupId>
46+
<artifactId>slf4j-simple</artifactId>
47+
<version>${slf4j.version}</version>
48+
<scope>test</scope>
49+
</dependency>
4450
</dependencies>
4551

4652
<build>

sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java

Lines changed: 25 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -109,16 +109,15 @@ void testParallelPartialFailure_failedBranchDoesNotPreventOthers() {
109109
void testParallelAllBranchesFail() {
110110
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
111111
var config = ParallelConfig.builder().build();
112-
var futures = new ArrayList<DurableFuture<String>>();
113112
var parallel = context.parallel("all-fail", config);
114113

115114
try (parallel) {
116-
futures.add(parallel.branch("branch-x", String.class, ctx -> {
115+
parallel.branch("branch-x", String.class, ctx -> {
117116
throw new RuntimeException("fail-x");
118-
}));
119-
futures.add(parallel.branch("branch-y", String.class, ctx -> {
117+
});
118+
parallel.branch("branch-y", String.class, ctx -> {
120119
throw new RuntimeException("fail-y");
121-
}));
120+
});
122121
}
123122

124123
var result = parallel.get();
@@ -298,7 +297,7 @@ void testStepBeforeAndAfterParallel() {
298297

299298
var config = ParallelConfig.builder().build();
300299
var futures = new ArrayList<DurableFuture<String>>();
301-
var parallel = context.parallel("middle-parallel", config);
300+
ParallelDurableFuture parallel = context.parallel("middle-parallel", config);
302301

303302
try (parallel) {
304303
futures.add(parallel.branch("branch-a", String.class, ctx -> "A"));
@@ -451,19 +450,18 @@ void testParallelUnlimitedConcurrencyWithToleratedFailureCount() {
451450
var config = ParallelConfig.builder()
452451
.completionConfig(CompletionConfig.toleratedFailureCount(1))
453452
.build();
454-
var futures = new ArrayList<DurableFuture<String>>();
455-
var parallel = context.parallel("unlimited-tolerated", config);
453+
ParallelDurableFuture parallel = context.parallel("unlimited-tolerated", config);
456454

457455
try (parallel) {
458-
futures.add(parallel.branch("branch-ok1", String.class, ctx -> "OK1"));
459-
futures.add(parallel.branch("branch-fail1", String.class, ctx -> {
456+
parallel.branch("branch-ok1", String.class, ctx -> "OK1");
457+
parallel.branch("branch-fail1", String.class, ctx -> {
460458
throw new RuntimeException("failed: fail1");
461-
}));
462-
futures.add(parallel.branch("branch-ok2", String.class, ctx -> "OK2"));
463-
futures.add(parallel.branch("branch-fail2", String.class, ctx -> {
459+
});
460+
parallel.branch("branch-ok2", String.class, ctx -> "OK2");
461+
parallel.branch("branch-fail2", String.class, ctx -> {
464462
throw new RuntimeException("failed: fail2");
465-
}));
466-
futures.add(parallel.branch("branch-ok3", String.class, ctx -> "OK3"));
463+
});
464+
parallel.branch("branch-ok3", String.class, ctx -> "OK3");
467465
}
468466

469467
var result = parallel.get();
@@ -508,7 +506,7 @@ void testParallelBranchesReturnDifferentTypes() {
508506
void testParallelResultSummary_succeededAndFailedCounts() {
509507
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
510508
var config = ParallelConfig.builder().build();
511-
var parallel = context.parallel("count-check", config);
509+
ParallelDurableFuture parallel = context.parallel("count-check", config);
512510

513511
try (parallel) {
514512
parallel.branch("ok1", String.class, ctx -> "OK1");
@@ -595,15 +593,14 @@ void testParallel50BranchesWithWaitForCallback_maxConcurrency5() {
595593

596594
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
597595
var config = ParallelConfig.builder().maxConcurrency(5).build();
598-
var futures = new ArrayList<DurableFuture<String>>();
599596
var parallel = context.parallel("50-callbacks-limited", config);
600597

601598
try (parallel) {
602599
for (int i = 0; i < branchCount; i++) {
603600
var idx = i;
604-
futures.add(parallel.branch("branch-" + i, String.class, ctx -> {
601+
parallel.branch("branch-" + i, String.class, ctx -> {
605602
return ctx.waitForCallback("cb-" + idx, String.class, (callbackId, stepCtx) -> {});
606-
}));
603+
});
607604
}
608605
}
609606

@@ -644,15 +641,14 @@ void testParallel50BranchesWithWaitForCallback_partialFailure() {
644641

645642
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
646643
var config = ParallelConfig.builder().build();
647-
var futures = new ArrayList<DurableFuture<String>>();
648644
var parallel = context.parallel("50-callbacks-partial-fail", config);
649645

650646
try (parallel) {
651647
for (int i = 0; i < branchCount; i++) {
652648
var idx = i;
653-
futures.add(parallel.branch("branch-" + i, String.class, ctx -> {
649+
parallel.branch("branch-" + i, String.class, ctx -> {
654650
return ctx.waitForCallback("approval-" + idx, String.class, (callbackId, stepCtx) -> {});
655-
}));
651+
});
656652
}
657653
}
658654

@@ -697,18 +693,17 @@ void testParallel50BranchesWithWaitForCallback_stepsBeforeAndAfterCallback() {
697693

698694
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
699695
var config = ParallelConfig.builder().build();
700-
var futures = new ArrayList<DurableFuture<String>>();
701-
var parallel = context.parallel("50-callbacks-with-steps", config);
696+
ParallelDurableFuture parallel = context.parallel("50-callbacks-with-steps", config);
702697

703698
try (parallel) {
704699
for (int i = 0; i < branchCount; i++) {
705700
var idx = i;
706-
futures.add(parallel.branch("branch-" + i, String.class, ctx -> {
701+
parallel.branch("branch-" + i, String.class, ctx -> {
707702
var before = ctx.step("prepare-" + idx, String.class, stepCtx -> "prepared-" + idx);
708703
var approval =
709704
ctx.waitForCallback("approval-" + idx, String.class, (callbackId, stepCtx) -> {});
710705
return ctx.step("finalize-" + idx, String.class, stepCtx -> before + ":" + approval + ":done");
711-
}));
706+
});
712707
}
713708
}
714709

@@ -890,20 +885,19 @@ void testParallel50BranchesMixed_callbackAndCondition() {
890885

891886
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
892887
var config = ParallelConfig.builder().build();
893-
var futures = new ArrayList<DurableFuture<String>>();
894888
var parallel = context.parallel("50-mixed", config);
895889

896890
try (parallel) {
897891
for (int i = 0; i < branchCount; i++) {
898892
var idx = i;
899893
if (i % 2 == 0) {
900894
// Even branches: waitForCallback
901-
futures.add(parallel.branch("branch-" + i, String.class, ctx -> {
895+
parallel.branch("branch-" + i, String.class, ctx -> {
902896
return ctx.waitForCallback("cb-" + idx, String.class, (callbackId, stepCtx) -> {});
903-
}));
897+
});
904898
} else {
905899
// Odd branches: waitForCondition
906-
futures.add(parallel.branch("branch-" + i, String.class, ctx -> {
900+
parallel.branch("branch-" + i, String.class, ctx -> {
907901
var strategy = WaitStrategies.<Integer>fixedDelay(5, Duration.ofSeconds(1));
908902
var wfcConfig = WaitForConditionConfig.<Integer>builder()
909903
.waitStrategy(strategy)
@@ -916,7 +910,7 @@ void testParallel50BranchesMixed_callbackAndCondition() {
916910
wfcConfig);
917911

918912
return "polled-" + polled;
919-
}));
913+
});
920914
}
921915
}
922916
}

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,7 @@ protected Operation waitForOperationCompletion() {
182182
validateCurrentThreadType();
183183

184184
var threadContext = getCurrentThreadContext();
185+
CompletableFuture<?> future = completionFuture;
185186

186187
// It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
187188
// completionFuture is completed by a user thread (a step or child context thread) when the execution here
@@ -201,7 +202,17 @@ protected Operation waitForOperationCompletion() {
201202
// Add a completion stage to completionFuture so that when the completionFuture is completed,
202203
// it will register the current Context thread synchronously to make sure it is always registered
203204
// strictly before the execution thread (Step or child context) is deregistered.
204-
completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId()));
205+
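// Wait on the chained stage rather than completionFuture itself, so the join() below also waits for the registration callback.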
206+
future = completionFuture.thenRun(() -> {
207+
logger.warn(
208+
"registering thread {} when operation {} ({}) completed ({})",
209+
threadContext.threadId(),
210+
getOperation(),
211+
getType(),
212+
completionFuture);
213+
214+
registerActiveThread(threadContext.threadId());
215+
});
205216

206217
// Deregister the current thread to allow suspension
207218
executionManager.deregisterActiveThread(threadContext.threadId());
@@ -210,7 +221,7 @@ protected Operation waitForOperationCompletion() {
210221

211222
// Block until operation completes. No-op if the future is already completed.
212223
try {
213-
completionFuture.join();
224+
future.join();
214225
} catch (Throwable throwable) {
215226
ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable));
216227
}
@@ -243,6 +254,7 @@ protected void runUserHandler(Runnable runnable, ThreadType threadType) {
243254
} finally {
244255
if (operationId != null) {
245256
try {
257+
logger.trace("deregistering thread {} after running user handler {}", operationId, getName());
246258
// if this is a child context or a step context, we need to
247259
// deregister the context's thread from the execution manager
248260
executionManager.deregisterActiveThread(operationId);
@@ -271,6 +283,7 @@ protected void runUserHandler(Runnable runnable, ThreadType threadType) {
271283
// 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside
272284
// the child context know which context they belong to.
273285
// registerActiveThread is idempotent (no-op if already registered).
286+
logger.trace("registering thread {} before running user handler {}", operationId, getName());
274287
registerActiveThread(operationId);
275288

276289
runningUserHandler.set(CompletableFuture.runAsync(

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,10 @@ private BaseDurableOperation waitForChildCompletion(
255255
future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
256256
// skip deregistering the current thread if there are more completed futures to process
257257
if (!future.isDone()) {
258-
future.thenRun(() -> registerActiveThread(threadContext.threadId()));
258+
future = future.thenApply(o -> {
259+
registerActiveThread(threadContext.threadId());
260+
return o;
261+
});
259262
// Deregister the current thread to allow suspension
260263
executionManager.deregisterActiveThread(threadContext.threadId());
261264
}

0 commit comments

Comments
 (0)