-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathParallelOperation.java
More file actions
176 lines (159 loc) · 7.57 KB
/
Copy pathParallelOperation.java
File metadata and controls
176 lines (159 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.operation;
import java.util.List;
import java.util.function.Function;
import software.amazon.awssdk.services.lambda.model.ContextOptions;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationAction;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableContext;
import software.amazon.lambda.durable.DurableFuture;
import software.amazon.lambda.durable.ParallelDurableFuture;
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.config.ParallelBranchConfig;
import software.amazon.lambda.durable.config.ParallelConfig;
import software.amazon.lambda.durable.context.DurableContextImpl;
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
import software.amazon.lambda.durable.model.OperationIdentifier;
import software.amazon.lambda.durable.model.OperationSubType;
import software.amazon.lambda.durable.model.ParallelResult;
import software.amazon.lambda.durable.serde.SerDes;
/**
 * Manages parallel execution of multiple branches as child context operations.
 *
 * <p>Extends {@link ConcurrencyOperation} to provide parallel-specific behavior:
 *
 * <ul>
 *   <li>Creates branches as {@link ChildContextOperation} with {@link OperationSubType#PARALLEL_BRANCH}
 *   <li>Checkpoints SUCCEED on the parallel context when completion criteria are met
 *   <li>Returns a {@link ParallelResult} summarising branch outcomes
 * </ul>
 *
 * <p>Context hierarchy:
 *
 * <pre>
 * DurableContext (root)
 * └── ParallelOperation context (ChildContextOperation with PARALLEL subtype)
 *     ├── Branch 1 context (ChildContextOperation with PARALLEL_BRANCH)
 *     ├── Branch 2 context (ChildContextOperation with PARALLEL_BRANCH)
 *     └── Branch N context (ChildContextOperation with PARALLEL_BRANCH)
 * </pre>
 *
 * <p>The {@code cachedResult} and {@code partialResult} fields may be written and read on different threads. Methods
 * that need one of them more than once copy it into a local first, so a single call sees one consistent value.
 */
public class ParallelOperation extends ConcurrencyOperation<ParallelResult> implements ParallelDurableFuture {
    /** Result built by {@link #handleCompletion}; may be written and read on different threads. */
    private volatile ParallelResult cachedResult;

    /**
     * Result recovered from a terminal checkpoint in {@link #replay}; {@code null} unless a completed parallel
     * operation with a stored result is being replayed. May be written and read on different threads.
     */
    private volatile ParallelResult partialResult;

    public ParallelOperation(
            OperationIdentifier operationIdentifier,
            SerDes resultSerDes,
            DurableContextImpl durableContext,
            ParallelConfig config) {
        super(
                operationIdentifier,
                TypeToken.get(ParallelResult.class),
                resultSerDes,
                durableContext,
                config.maxConcurrency(),
                config.completionConfig().minSuccessful(),
                config.completionConfig().toleratedFailureCount(),
                config.nestingType());
    }

    /**
     * Builds the {@link ParallelResult} from the current branches and checkpoints SUCCEED for the parallel context.
     *
     * <p>Branches that have not completed are reported as {@link ParallelResult.Status#SKIPPED}.
     *
     * @param concurrencyCompletionStatus completion status recorded in the result
     */
    @Override
    protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletionStatus) {
        // Snapshot the branches so every count below is computed over the same set.
        var items = List.copyOf(getBranches());
        var statuses = items.stream().map(this::getParallelItemStatus).toList();
        int succeededCount = Math.toIntExact(statuses.stream()
                .filter(s -> s == ParallelResult.Status.SUCCEEDED)
                .count());
        int failedCount = Math.toIntExact(
                statuses.stream().filter(s -> s == ParallelResult.Status.FAILED).count());
        int skippedCount = items.size() - succeededCount - failedCount;
        cachedResult = new ParallelResult(
                items.size(), succeededCount, failedCount, skippedCount, concurrencyCompletionStatus, statuses);
        // Branches added after checkpoint will not exist in the checkpointed result, but they'll be in the returned
        // value from get() method.
        sendOperationUpdate(OperationUpdate.builder()
                .action(OperationAction.SUCCEED)
                .subType(getSubType().getValue())
                .payload(serializeResult(cachedResult))
                .contextOptions(ContextOptions.builder().replayChildren(true).build()));
    }

    /**
     * Classifies one branch. A branch that has not completed is SKIPPED; a completed branch is SUCCEEDED if its
     * {@code get()} returns normally and FAILED if it throws.
     */
    private ParallelResult.Status getParallelItemStatus(ChildContextOperation<?> childContextOperation) {
        if (!childContextOperation.isOperationCompleted()) {
            return ParallelResult.Status.SKIPPED;
        }
        try {
            childContextOperation.get();
            return ParallelResult.Status.SUCCEEDED;
        } catch (Throwable t) {
            // Any throwable from a completed branch is treated as a branch failure.
            return ParallelResult.Status.FAILED;
        }
    }

    /**
     * Returns the cached result. If branches were added after the result was built, returns a copy whose size is
     * the current branch count. Returns {@code null} if no result has been built yet.
     */
    private ParallelResult rebuildParallelResult() {
        // Read the volatile field and the branch count once, so the size check and the rebuilt value agree.
        var result = cachedResult;
        if (result == null) {
            return null;
        }
        int branchCount = getBranches().size();
        if (result.size() == branchCount) {
            return result;
        }
        // NOTE(review): only the size is widened; succeeded/failed/skipped and statuses still describe the
        // branches that existed when handleCompletion ran.
        return new ParallelResult(
                branchCount,
                result.succeeded(),
                result.failed(),
                result.skipped(),
                result.completionStatus(),
                result.statuses());
    }

    /** Checkpoints START for the parallel context asynchronously, then begins executing branches. */
    @Override
    protected void start() {
        sendOperationUpdateAsync(OperationUpdate.builder()
                .action(OperationAction.START)
                .subType(getSubType().getValue()));
        executeItems();
    }

    /**
     * Re-executes branches during replay. Child branches replay themselves via their own ChildContextOperation.
     *
     * <p>If the existing operation is terminal and carries a stored result, that result is kept in
     * {@code partialResult}. {@link #branch} then uses it to decide which branches to skip. The number of branches
     * that succeeded or failed, together with the stored completion status, is passed to {@code executeItems} as
     * the expected completion. Otherwise branches are executed normally.
     */
    @Override
    protected void replay(Operation existing) {
        if (ExecutionManager.isTerminalStatus(existing.status())) {
            var details = existing.contextDetails();
            var recovered = details != null ? deserializeResult(details.result()) : null;
            // Publish before executing items so branch() can see the recovered statuses.
            partialResult = recovered;
            if (recovered != null) {
                var expected = new ExpectedCompletionStatus(
                        recovered.succeeded() + recovered.failed(), recovered.completionStatus());
                executeItems(expected);
                return;
            }
        }
        executeItems();
    }

    /**
     * Waits for the parallel operation to complete and returns its result.
     *
     * @return the parallel result, with size reflecting any branches added after completion
     */
    @Override
    public ParallelResult get() {
        join();
        return rebuildParallelResult();
    }

    /** Joins the operation if it has not been joined yet. Guarantees that the context is closed. */
    @Override
    public void close() {
        if (isJoined.get()) {
            return;
        }
        join();
    }

    /**
     * Adds a branch to this parallel operation.
     *
     * <p>When replaying a completed parallel operation, the branch is marked to be skipped in two cases: the stored
     * result has no status at this branch's index, or the status there is SKIPPED.
     *
     * @param name branch name
     * @param resultType type token for the branch result
     * @param func branch body; receives a {@link DurableContext}
     * @param config per-branch configuration; its {@code serDes}, when non-null, overrides the context default
     * @return a future for the branch result
     * @throws IllegalStateException if {@code join()} has already been called
     */
    public <T> DurableFuture<T> branch(
            String name, TypeToken<T> resultType, Function<DurableContext, T> func, ParallelBranchConfig config) {
        if (isJoined.get()) {
            throw new IllegalStateException("Cannot add branches after join() has been called");
        }
        var nextBranchIndex = getBranches().size();
        // Read the volatile field once so the null check and the status lookup use the same snapshot.
        var recovered = partialResult;
        // ConcurrencyOperation will skip this branch if skip=true:
        // 1. if the parallel operation is already completed (partialResult is not null)
        // 2. if the branch is already skipped in the partialResult or nonexistent in the partialResult
        var skip = recovered != null
                && (recovered.statuses().size() <= nextBranchIndex
                        || recovered.statuses().get(nextBranchIndex) == ParallelResult.Status.SKIPPED);
        var serDes = config.serDes() == null ? getContext().getDurableConfig().getSerDes() : config.serDes();
        return enqueueItem(name, func, resultType, serDes, OperationSubType.PARALLEL_BRANCH, skip);
    }
}