Skip to content

Commit a19fc08

Browse files
wangyb-AAlex Wang
andauthored
feat: [Parallel] Add parallel result (aws#246)
* Add parallel result * Update examples --------- Co-authored-by: Alex Wang <wangyb@amazon.com>
1 parent d77f201 commit a19fc08

12 files changed

Lines changed: 245 additions & 89 deletions

File tree

examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import software.amazon.lambda.durable.DurableFuture;
99
import software.amazon.lambda.durable.DurableHandler;
1010
import software.amazon.lambda.durable.ParallelConfig;
11+
import software.amazon.lambda.durable.model.ParallelResult;
1112

1213
/**
1314
* Example demonstrating parallel branch execution with the Durable Execution SDK.
@@ -38,8 +39,9 @@ public Output handleRequest(Input input, DurableContext context) {
3839
var config = ParallelConfig.builder().build();
3940

4041
var futures = new ArrayList<DurableFuture<String>>(items.size());
42+
var parallel = context.parallel("process-items", config);
4143

42-
try (var parallel = context.parallel("process-items", config)) {
44+
try (parallel) {
4345
for (var item : items) {
4446
var future = parallel.branch("process-" + item, String.class, branchCtx -> {
4547
branchCtx.getLogger().info("Processing item: {}", item);
@@ -49,7 +51,12 @@ public Output handleRequest(Input input, DurableContext context) {
4951
}
5052
} // join() called here via AutoCloseable
5153

52-
logger.info("All branches complete, collecting results");
54+
ParallelResult parallelResult = parallel.get();
55+
logger.info(
56+
"Parallel complete: total={}, succeeded={}, failed={}",
57+
parallelResult.getTotalBranches(),
58+
parallelResult.getSucceededBranches(),
59+
parallelResult.getFailedBranches());
5360

5461
var results = futures.stream().map(DurableFuture::get).toList();
5562

examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import software.amazon.lambda.durable.DurableHandler;
1010
import software.amazon.lambda.durable.ParallelConfig;
1111
import software.amazon.lambda.durable.StepConfig;
12+
import software.amazon.lambda.durable.model.ParallelResult;
1213
import software.amazon.lambda.durable.retry.RetryStrategies;
1314

1415
/**
@@ -24,22 +25,24 @@
2425
public class ParallelFailureToleranceExample
2526
extends DurableHandler<ParallelFailureToleranceExample.Input, ParallelFailureToleranceExample.Output> {
2627

27-
public record Input(List<String> services, int toleratedFailures) {}
28+
public record Input(List<String> services, int toleratedFailures, int minSuccessful) {}
2829

29-
public record Output(List<String> succeeded, List<String> failed) {}
30+
public record Output(int succeeded, int failed) {}
3031

3132
@Override
3233
public Output handleRequest(Input input, DurableContext context) {
3334
var logger = context.getLogger();
3435
logger.info("Starting parallel execution with toleratedFailureCount={}", input.toleratedFailures());
3536

3637
var config = ParallelConfig.builder()
38+
.minSuccessful(input.minSuccessful())
3739
.toleratedFailureCount(input.toleratedFailures())
3840
.build();
3941

4042
var futures = new ArrayList<DurableFuture<String>>(input.services().size());
43+
var parallel = context.parallel("call-services", config);
4144

42-
try (var parallel = context.parallel("call-services", config)) {
45+
try (parallel) {
4346
for (var service : input.services()) {
4447
var future = parallel.branch("call-" + service, String.class, branchCtx -> {
4548
return branchCtx.step(
@@ -59,20 +62,17 @@ public Output handleRequest(Input input, DurableContext context) {
5962
}
6063
}
6164

62-
var succeeded = new ArrayList<String>();
63-
var failed = new ArrayList<String>();
65+
ParallelResult parallelResult = parallel.get();
66+
logger.info(
67+
"Parallel complete: succeeded={}, failed={}, status={}",
68+
parallelResult.getSucceededBranches(),
69+
parallelResult.getFailedBranches(),
70+
parallelResult.getCompletionStatus().isSucceeded() ? "succeeded" : "failed");
6471

65-
for (int i = 0; i < futures.size(); i++) {
66-
try {
67-
var result = futures.get(i).get();
68-
succeeded.add(result);
69-
} catch (Exception e) {
70-
failed.add(input.services().get(i));
71-
logger.info("Branch failed for service {}: {}", input.services().get(i), e.getMessage());
72-
}
73-
}
72+
var succeeded = parallelResult.getSucceededBranches();
73+
var failed = parallelResult.getFailedBranches();
7474

75-
logger.info("Completed: {} succeeded, {} failed", succeeded.size(), failed.size());
75+
logger.info("Completed: {} succeeded, {} failed", succeeded, failed);
7676
return new Output(succeeded, failed);
7777
}
7878
}

examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import software.amazon.lambda.durable.DurableFuture;
1010
import software.amazon.lambda.durable.DurableHandler;
1111
import software.amazon.lambda.durable.ParallelConfig;
12+
import software.amazon.lambda.durable.model.ParallelResult;
1213

1314
/**
1415
* Example demonstrating parallel branches where some branches include wait operations.
@@ -29,7 +30,7 @@ public class ParallelWithWaitExample
2930

3031
public record Input(String userId, String message) {}
3132

32-
public record Output(List<String> deliveries) {}
33+
public record Output(List<String> deliveries, int success, int faiure) {}
3334

3435
@Override
3536
public Output handleRequest(Input input, DurableContext context) {
@@ -38,8 +39,9 @@ public Output handleRequest(Input input, DurableContext context) {
3839

3940
var config = ParallelConfig.builder().build();
4041
var futures = new ArrayList<DurableFuture<String>>(3);
42+
var parallel = context.parallel("notify", config);
4143

42-
try (var parallel = context.parallel("notify", config)) {
44+
try (parallel) {
4345

4446
// Branch 1: email — no wait, deliver immediately
4547
futures.add(parallel.branch("email", String.class, ctx -> {
@@ -60,10 +62,12 @@ public Output handleRequest(Input input, DurableContext context) {
6062
}));
6163
}
6264

65+
ParallelResult result = parallel.get();
66+
6367
var deliveries = futures.stream().map(DurableFuture::get).toList();
6468
logger.info("All {} notifications delivered", deliveries.size());
6569
// Test replay
6670
context.wait("wait for finalization", Duration.ofSeconds(5));
67-
return new Output(deliveries);
71+
return new Output(deliveries, result.getSucceededBranches(), result.getFailedBranches());
6872
}
6973
}

examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,28 @@ void succeedsWhenFailuresAreWithinTolerance() {
1717
var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler);
1818

1919
// 2 good services, 1 bad — toleratedFailureCount=1 so the parallel op still succeeds
20-
var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1);
20+
var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1, -1);
2121
var result = runner.runUntilComplete(input);
2222

2323
assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus());
2424

2525
var output = result.getResult(ParallelFailureToleranceExample.Output.class);
26-
assertEquals(2, output.succeeded().size());
27-
assertEquals(1, output.failed().size());
28-
assertTrue(output.succeeded().contains("ok:svc-a"));
29-
assertTrue(output.succeeded().contains("ok:svc-c"));
30-
assertTrue(output.failed().contains("bad-svc-b"));
26+
assertEquals(2, output.succeeded());
27+
assertEquals(1, output.failed());
3128
}
3229

3330
@Test
3431
void succeedsWhenAllBranchesSucceed() {
3532
var handler = new ParallelFailureToleranceExample();
3633
var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler);
3734

38-
var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2);
35+
var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2, -1);
3936
var result = runner.runUntilComplete(input);
4037

4138
assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus());
4239

4340
var output = result.getResult(ParallelFailureToleranceExample.Output.class);
44-
assertEquals(3, output.succeeded().size());
45-
assertTrue(output.failed().isEmpty());
41+
assertEquals(3, output.succeeded());
4642
}
4743

4844
@Test
@@ -51,13 +47,13 @@ void failsWhenFailuresExceedTolerance() {
5147
var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler);
5248

5349
// 2 bad services, toleratedFailureCount=1 — second failure exceeds tolerance
54-
var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "bad-svc-c"), 1);
50+
var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "bad-svc-c"), 1, 2);
5551
var result = runner.runUntilComplete(input);
5652

5753
assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus());
5854

5955
var output = result.getResult(ParallelFailureToleranceExample.Output.class);
60-
assertEquals(2, output.failed().size());
61-
assertEquals(1, output.succeeded().size());
56+
assertEquals(2, output.failed());
57+
assertEquals(1, output.succeeded());
6258
}
6359
}

examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExampleTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ void completesAfterManuallyAdvancingWaits() {
2929

3030
var output = result.getResult(ParallelWithWaitExample.Output.class);
3131
assertEquals(List.of("email:world", "sms:world", "push:world"), output.deliveries());
32+
assertEquals(3, output.success());
3233
}
3334
}

sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,25 @@
33
package software.amazon.lambda.durable;
44

55
import java.util.Objects;
6+
import java.util.concurrent.atomic.AtomicBoolean;
67
import java.util.function.Function;
8+
import software.amazon.lambda.durable.model.ParallelResult;
79
import software.amazon.lambda.durable.operation.ParallelOperation;
810

911
/** User-facing context for managing parallel branch execution within a durable function. */
10-
public class ParallelContext implements AutoCloseable {
12+
public class ParallelContext implements AutoCloseable, DurableFuture<ParallelResult> {
1113

12-
private final ParallelOperation<?> parallelOperation;
14+
private final ParallelOperation parallelOperation;
1315
private final DurableContext durableContext;
14-
private boolean joined;
16+
private final AtomicBoolean joined = new AtomicBoolean(false);
1517

1618
/**
1719
* Creates a new ParallelContext.
1820
*
1921
* @param parallelOperation the underlying parallel operation managing concurrency
2022
* @param durableContext the durable context for creating child operations
2123
*/
22-
public ParallelContext(ParallelOperation<?> parallelOperation, DurableContext durableContext) {
24+
public ParallelContext(ParallelOperation parallelOperation, DurableContext durableContext) {
2325
this.parallelOperation = Objects.requireNonNull(parallelOperation, "parallelOperation cannot be null");
2426
this.durableContext = Objects.requireNonNull(durableContext, "durableContext cannot be null");
2527
}
@@ -49,7 +51,7 @@ public <T> DurableFuture<T> branch(String name, Class<T> resultType, Function<Du
4951
* @throws IllegalStateException if called after {@link #join()}
5052
*/
5153
public <T> DurableFuture<T> branch(String name, TypeToken<T> resultType, Function<DurableContext, T> func) {
52-
if (joined) {
54+
if (joined.get()) {
5355
throw new IllegalStateException("Cannot add branches after join() has been called");
5456
}
5557
return parallelOperation.addItem(
@@ -66,11 +68,23 @@ public <T> DurableFuture<T> branch(String name, TypeToken<T> resultType, Functio
6668
* @throws software.amazon.lambda.durable.exception.ConcurrencyExecutionException if failure threshold exceeded
6769
*/
6870
public void join() {
69-
if (joined) {
71+
if (!joined.compareAndSet(false, true)) {
7072
return;
7173
}
72-
joined = true;
73-
parallelOperation.get();
74+
parallelOperation.join();
75+
}
76+
77+
/**
78+
* Blocks until the parallel operation completes and returns the {@link ParallelResult}.
79+
*
80+
* <p>Calling {@code get()} implicitly calls {@code join()} if it has not been called yet.
81+
*
82+
* @return the {@link ParallelResult} summarising branch outcomes
83+
*/
84+
@Override
85+
public ParallelResult get() {
86+
joined.set(true);
87+
return parallelOperation.get();
7488
}
7589

7690
/**

sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,9 +558,8 @@ public ParallelContext parallel(String name, ParallelConfig config) {
558558
Objects.requireNonNull(config, "config cannot be null");
559559
var operationId = nextOperationId();
560560

561-
var parallelOp = new ParallelOperation<>(
561+
var parallelOp = new ParallelOperation(
562562
OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL),
563-
TypeToken.get(Void.class),
564563
getDurableConfig().getSerDes(),
565564
this,
566565
config.maxConcurrency(),
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable.model;
4+
5+
/**
6+
* Summary result of a parallel operation.
7+
*
8+
* <p>Captures the aggregate outcome of a parallel execution: how many branches were registered, how many succeeded, how
9+
* many failed, and why the operation completed.
10+
*/
11+
public class ParallelResult {
12+
13+
private final int totalBranches;
14+
private final int succeededBranches;
15+
private final int failedBranches;
16+
private final ConcurrencyCompletionStatus completionStatus;
17+
18+
public ParallelResult(
19+
int totalBranches,
20+
int succeededBranches,
21+
int failedBranches,
22+
ConcurrencyCompletionStatus completionStatus) {
23+
this.totalBranches = totalBranches;
24+
this.succeededBranches = succeededBranches;
25+
this.failedBranches = failedBranches;
26+
this.completionStatus = completionStatus;
27+
}
28+
29+
/** Returns the total number of branches registered before {@code join()} was called. */
30+
public int getTotalBranches() {
31+
return totalBranches;
32+
}
33+
34+
/** Returns the number of branches that completed without throwing. */
35+
public int getSucceededBranches() {
36+
return succeededBranches;
37+
}
38+
39+
/** Returns the number of branches that threw an exception. */
40+
public int getFailedBranches() {
41+
return failedBranches;
42+
}
43+
44+
/** Returns the status indicating why the parallel operation completed. */
45+
public ConcurrencyCompletionStatus getCompletionStatus() {
46+
return completionStatus;
47+
}
48+
}

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public abstract class ConcurrencyOperation<T> extends BaseDurableOperation<T> {
5454
private final Set<String> completedOperations = Collections.synchronizedSet(new HashSet<String>());
5555
private OperationIdGenerator operationIdGenerator;
5656
private final DurableContextImpl rootContext;
57+
private ConcurrencyCompletionStatus completionStatus;
5758

5859
protected ConcurrencyOperation(
5960
OperationIdentifier operationIdentifier,
@@ -203,9 +204,9 @@ public void onItemComplete(ChildContextOperation<?> child) {
203204
}
204205
runningCount.decrementAndGet();
205206

206-
var status = canComplete();
207-
if (status != null) {
208-
handleComplete(status);
207+
this.completionStatus = canComplete();
208+
if (this.completionStatus != null) {
209+
handleComplete(this.completionStatus);
209210
} else {
210211
executeNextItemIfAllowed();
211212
}
@@ -245,17 +246,13 @@ private void handleComplete(ConcurrencyCompletionStatus status) {
245246
* Blocks the calling thread until the concurrency operation reaches a terminal state. Validates item count, handles
246247
* zero-branch case, then delegates to {@code waitForOperationCompletion()} from BaseDurableOperation.
247248
*/
248-
protected void join() {
249+
public void join() {
249250
validateItemCount();
250251
isJoined.set(true);
251-
if (childOperations.isEmpty()) {
252-
return;
253-
}
254-
255252
synchronized (this) {
256-
var status = canComplete();
257-
if (status != null) {
258-
handleComplete(status);
253+
this.completionStatus = canComplete();
254+
if (this.completionStatus != null) {
255+
handleComplete(this.completionStatus);
259256
}
260257
}
261258

@@ -274,6 +271,10 @@ protected int getTotalItems() {
274271
return childOperations.size();
275272
}
276273

274+
protected ConcurrencyCompletionStatus getCompletionStatus() {
275+
return completionStatus;
276+
}
277+
277278
protected List<ChildContextOperation<?>> getChildOperations() {
278279
return Collections.unmodifiableList(childOperations);
279280
}

0 commit comments

Comments
 (0)