Skip to content

Commit e460f65

Browse files
committed
add a prototype for map/parallel
1 parent 24fe07f commit e460f65

15 files changed

Lines changed: 403 additions & 6 deletions

File tree

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable.examples;
4+
5+
import java.util.List;
6+
import software.amazon.lambda.durable.ConcurrencyConfig;
7+
import software.amazon.lambda.durable.DurableContext;
8+
import software.amazon.lambda.durable.DurableHandler;
9+
import software.amazon.lambda.durable.ParallelBranchConfig;
10+
import software.amazon.lambda.durable.TypeToken;
11+
12+
/**
13+
* Simple example demonstrating basic step execution with the Durable Execution SDK.
14+
*
15+
* <p>This handler processes a greeting request through three sequential steps:
16+
*
17+
* <ol>
18+
* <li>Create greeting message
19+
* <li>Transform to uppercase
20+
* <li>Add punctuation
21+
* </ol>
22+
*/
23+
public class MapExample extends DurableHandler<GreetingRequest, String> {
24+
25+
@Override
26+
public String handleRequest(GreetingRequest input, DurableContext context) {
27+
var squared = context.mapAsync(
28+
"map example",
29+
List.of(1, 2, 3),
30+
(ctx, item, index) -> item * item,
31+
TypeToken.get(Integer.class),
32+
new ConcurrencyConfig(10, 2, 1));
33+
34+
var parallel = context.parallelAsync("parallel example", new ConcurrencyConfig(10, 2, 1));
35+
var b1 = parallel.branch("branch1", TypeToken.get(String.class), ctx -> "hello", new ParallelBranchConfig());
36+
var b2 = parallel.branch("branch2", TypeToken.get(String.class), ctx -> "world", new ParallelBranchConfig());
37+
38+
var result = parallel.get();
39+
return b1.get() + " " + b2.get();
40+
}
41+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable;
4+
5+
public class BatchResult<T> extends ParallelResult {
6+
// results/errors as well as the statistics
7+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable;
4+
5+
public record ConcurrencyConfig(int maxConcurrency, int minSuccessful, int toleratedFailureCount) {}

sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.security.NoSuchAlgorithmException;
99
import java.time.Duration;
1010
import java.util.HexFormat;
11+
import java.util.List;
1112
import java.util.Objects;
1213
import java.util.concurrent.atomic.AtomicInteger;
1314
import java.util.function.BiConsumer;
@@ -21,8 +22,11 @@
2122
import software.amazon.lambda.durable.operation.CallbackOperation;
2223
import software.amazon.lambda.durable.operation.ChildContextOperation;
2324
import software.amazon.lambda.durable.operation.InvokeOperation;
25+
import software.amazon.lambda.durable.operation.MapOperation;
26+
import software.amazon.lambda.durable.operation.ParallelOperation;
2427
import software.amazon.lambda.durable.operation.StepOperation;
2528
import software.amazon.lambda.durable.operation.WaitOperation;
29+
import software.amazon.lambda.durable.serde.JacksonSerDes;
2630
import software.amazon.lambda.durable.validation.ParameterValidator;
2731

2832
public class DurableContext extends BaseContext {
@@ -335,7 +339,7 @@ private <T> DurableFuture<T> runInChildContextAsync(
335339
var operationId = nextOperationId();
336340

337341
var operation = new ChildContextOperation<>(
338-
operationId, name, func, subType, typeToken, getDurableConfig().getSerDes(), this);
342+
operationId, name, func, subType, typeToken, getDurableConfig().getSerDes(), this, null);
339343

340344
operation.execute();
341345
return operation;
@@ -438,6 +442,28 @@ public <T> DurableFuture<T> waitForCallbackAsync(
438442
OperationSubType.WAIT_FOR_CALLBACK);
439443
}
440444

445+
// parallel operations
446+
public DurableParallelFuture parallelAsync(String name, ConcurrencyConfig config) {
447+
var operationId = nextOperationId();
448+
var operation = new ParallelOperation(operationId, name, config, this);
449+
operation.execute();
450+
return operation;
451+
}
452+
453+
// map operations
454+
public <T, I> DurableFuture<BatchResult<T>> mapAsync(
455+
String name,
456+
List<I> collection,
457+
MapFunction<I, T> func,
458+
TypeToken<T> resultTypeToken,
459+
ConcurrencyConfig config) {
460+
var operationId = nextOperationId();
461+
var operation = new MapOperation<>(
462+
operationId, name, collection, func, resultTypeToken, new JacksonSerDes(), config, this);
463+
operation.execute();
464+
return operation;
465+
}
466+
441467
// =============== accessors ================
442468
/**
443469
* Returns a logger with execution context information for replay-aware logging.
@@ -474,7 +500,7 @@ public void close() {
474500
* prefixed with the parent hashed contextId (e.g. "<hash>-1", "<hash>-2" inside parent context <hash>). This
475501
* matches the Python SDK's stepPrefix convention and prevents ID collisions in checkpoint batches.
476502
*/
477-
private String nextOperationId() {
503+
public String nextOperationId() {
478504
var counter = String.valueOf(operationCounter.incrementAndGet());
479505
var rawId = getContextId() != null ? getContextId() + "-" + counter : counter;
480506
try {
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable;
4+
5+
import java.util.function.Function;
6+
7+
public interface DurableParallelFuture extends DurableFuture<ParallelResult> {
8+
<T> DurableFuture<T> branch(
9+
String name, TypeToken<T> resultType, Function<DurableContext, T> func, ParallelBranchConfig config);
10+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable;
4+
5+
@FunctionalInterface
6+
public interface MapFunction<I, O> {
7+
O apply(DurableContext context, I item, int index);
8+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable;
4+
5+
public class ParallelBranchConfig {
6+
// SerDes and etc
7+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable;
4+
5+
/** Statistics of a parallel operation (succeeded, failed, etc.) */
6+
public class ParallelResult {}

sdk/src/main/java/software/amazon/lambda/durable/model/OperationSubType.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
public enum OperationSubType {
1212
RUN_IN_CHILD_CONTEXT("RunInChildContext"),
1313
MAP("Map"),
14+
MAP_ITERATION("MapInteration"),
1415
PARALLEL("Parallel"),
16+
PARALLEL_BRANCH("ParallelBranch"),
1517
WAIT_FOR_CALLBACK("WaitForCallback");
1618

1719
private final String value;
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.lambda.durable.operation;
4+
5+
import java.util.ArrayList;
6+
import java.util.Queue;
7+
import java.util.concurrent.ConcurrentLinkedQueue;
8+
import java.util.concurrent.atomic.AtomicInteger;
9+
import java.util.function.Function;
10+
import software.amazon.awssdk.services.lambda.model.OperationAction;
11+
import software.amazon.awssdk.services.lambda.model.OperationType;
12+
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
13+
import software.amazon.lambda.durable.ConcurrencyConfig;
14+
import software.amazon.lambda.durable.DurableContext;
15+
import software.amazon.lambda.durable.TypeToken;
16+
import software.amazon.lambda.durable.model.OperationSubType;
17+
import software.amazon.lambda.durable.serde.NoopSerDes;
18+
import software.amazon.lambda.durable.serde.SerDes;
19+
20+
public abstract class BaseConcurrentOperation<R> extends BaseDurableOperation<R> {
21+
22+
private final ArrayList<ChildContextOperation<?>> branches;
23+
private final Queue<ChildContextOperation<?>> queue;
24+
private final DurableContext rootContext;
25+
private final AtomicInteger succeeded;
26+
private final AtomicInteger failed;
27+
private final OperationSubType subType;
28+
private final ConcurrencyConfig config;
29+
private final AtomicInteger activeBranches;
30+
31+
public BaseConcurrentOperation(
32+
String operationId,
33+
String name,
34+
OperationSubType subType,
35+
ConcurrencyConfig config,
36+
DurableContext durableContext) {
37+
super(operationId, name, OperationType.CONTEXT, new TypeToken<>() {}, new NoopSerDes(), durableContext);
38+
this.branches = new ArrayList<>();
39+
this.queue = new ConcurrentLinkedQueue<>();
40+
this.rootContext = durableContext.createChildContext(operationId, name);
41+
this.config = config;
42+
this.succeeded = new AtomicInteger(0);
43+
this.failed = new AtomicInteger(0);
44+
this.subType = subType;
45+
this.activeBranches = new AtomicInteger(0);
46+
}
47+
48+
protected <T> ChildContextOperation<T> branchInternal(
49+
String name, TypeToken<T> resultType, SerDes resultSerDes, Function<DurableContext, T> func) {
50+
var operationId = this.rootContext.nextOperationId();
51+
ChildContextOperation<T> operation;
52+
53+
synchronized (this.branches) {
54+
operation = new ChildContextOperation<>(
55+
operationId,
56+
name,
57+
func,
58+
OperationSubType.PARALLEL_BRANCH,
59+
resultType,
60+
resultSerDes,
61+
rootContext,
62+
this);
63+
branches.add(operation);
64+
queue.add(operation);
65+
}
66+
67+
executeNewBranchIfConcurrencyAllows();
68+
69+
return operation;
70+
}
71+
72+
private void executeNewBranchIfConcurrencyAllows() {
73+
synchronized (this) {
74+
// use one extra thread from user's thread pool to wait for the semaphore
75+
if (activeBranches.get() < config.maxConcurrency()) {
76+
if (!queue.isEmpty()) {
77+
activeBranches.incrementAndGet();
78+
79+
var op = queue.poll();
80+
op.execute();
81+
}
82+
}
83+
}
84+
}
85+
86+
@Override
87+
public <T> void onChildContextComplete(ChildContextOperation<T> parallelBranchOperation) {
88+
if (isOperationCompleted()) {
89+
return;
90+
}
91+
92+
activeBranches.decrementAndGet();
93+
94+
// handle branch results
95+
try {
96+
parallelBranchOperation.get();
97+
succeeded.incrementAndGet();
98+
} catch (Exception e) {
99+
failed.incrementAndGet();
100+
}
101+
102+
if (isDone()) {
103+
sendOperationUpdateAsync(OperationUpdate.builder()
104+
.action(OperationAction.SUCCEED)
105+
.subType(OperationSubType.PARALLEL.getValue())
106+
.payload(""));
107+
108+
rootContext.close();
109+
} else {
110+
// we must make sure the thread for the new branch is registered before the child thread is deregistered
111+
executeNewBranchIfConcurrencyAllows();
112+
}
113+
}
114+
115+
private boolean isDone() {
116+
return succeeded.get() >= config.minSuccessful() || failed.get() > config.toleratedFailureCount();
117+
}
118+
}

0 commit comments

Comments
 (0)