Skip to content

Commit 5fae79d

Browse files
authored
[fix]: fix ConcurrentModificationException when completing invocations (#362)
* fix ConcurrentModificationException when completing invocations
* fix non-existent branches in checkpointed result
1 parent bbaa0d0 commit 5fae79d

3 files changed

Lines changed: 16 additions & 6 deletions

File tree

sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import java.time.Instant;
66
import java.util.ArrayList;
77
import java.util.Collections;
8-
import java.util.HashMap;
98
import java.util.HashSet;
109
import java.util.List;
1110
import java.util.Map;
1211
import java.util.Objects;
1312
import java.util.Set;
1413
import java.util.concurrent.CancellationException;
1514
import java.util.concurrent.CompletableFuture;
15+
import java.util.concurrent.ConcurrentHashMap;
1616
import java.util.concurrent.ThreadPoolExecutor;
1717
import java.util.concurrent.atomic.AtomicReference;
1818
import java.util.stream.Collectors;
@@ -59,7 +59,7 @@ public class ExecutionManager implements AutoCloseable {
5959
private final DurableConfig durableConfig;
6060

6161
// ===== Thread Coordination =====
62-
private final Map<String, BaseDurableOperation> registeredOperations = Collections.synchronizedMap(new HashMap<>());
62+
private final Map<String, BaseDurableOperation> registeredOperations = new ConcurrentHashMap<>();
6363
private final Set<String> activeThreads = Collections.synchronizedSet(new HashSet<>());
6464
private static final ThreadLocal<ThreadContext> currentThreadContext = new ThreadLocal<>();
6565
private final CompletableFuture<Void> executionExceptionFuture = new CompletableFuture<>();

sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-License-Identifier: Apache-2.0
33
package software.amazon.lambda.durable.operation;
44

5+
import java.util.List;
56
import java.util.function.Function;
67
import software.amazon.awssdk.services.lambda.model.ContextOptions;
78
import software.amazon.awssdk.services.lambda.model.Operation;
@@ -66,7 +67,8 @@ public ParallelOperation(
6667

6768
@Override
6869
protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletionStatus) {
69-
var items = getBranches();
70+
71+
var items = List.copyOf(getBranches());
7072
var statuses = items.stream().map(this::getParallelItemStatus).toList();
7173
int succeededCount = Math.toIntExact(statuses.stream()
7274
.filter(s -> s == ParallelResult.Status.SUCCEEDED)
@@ -76,6 +78,9 @@ protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletio
7678
int skippedCount = items.size() - succeededCount - failedCount;
7779
cachedResult = new ParallelResult(
7880
items.size(), succeededCount, failedCount, skippedCount, concurrencyCompletionStatus, statuses);
81+
82+
// Branches added after the checkpoint will not appear in the checkpointed result, but they will be included
83+
// in the value returned by get().
7984
sendOperationUpdate(OperationUpdate.builder()
8085
.action(OperationAction.SUCCEED)
8186
.subType(getSubType().getValue())
@@ -157,9 +162,14 @@ public <T> DurableFuture<T> branch(
157162
throw new IllegalStateException("Cannot add branches after join() has been called");
158163
}
159164

160-
// ConcurrencyOperation will skip this branch if skip=true
165+
var nextBranchIndex = getBranches().size();
166+
167+
// ConcurrencyOperation will skip this branch if skip=true, which requires both of the following:
168+
// 1. the parallel operation is already completed (partialResult is not null), and
169+
// 2. the branch is either marked SKIPPED in the partialResult or does not exist in the partialResult
161170
var skip = partialResult != null
162-
&& partialResult.statuses().get(getBranches().size()) == ParallelResult.Status.SKIPPED;
171+
&& (partialResult.statuses().size() <= nextBranchIndex
172+
|| partialResult.statuses().get(nextBranchIndex) == ParallelResult.Status.SKIPPED);
163173
var serDes = config.serDes() == null ? getContext().getDurableConfig().getSerDes() : config.serDes();
164174
return enqueueItem(name, func, resultType, serDes, OperationSubType.PARALLEL_BRANCH, skip);
165175
}

sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ void minSuccessful_notExecuteSkippedBranchWhenReplay() {
270270
.status(OperationStatus.SUCCEEDED)
271271
.contextDetails(ContextDetails.builder()
272272
.result(
273-
"{\"succeeded\": 1, \"completionStatus\": \"MIN_SUCCESSFUL_REACHED\", \"statuses\":[\"SKIPPED\", \"SUCCEEDED\"]}")
273+
"{\"size\": 2, \"skipped\": 1, \"succeeded\": 1, \"completionStatus\": \"MIN_SUCCESSFUL_REACHED\", \"statuses\":[\"SKIPPED\", \"SUCCEEDED\"]}")
274274
.build())
275275
.build());
276276
when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2))

0 commit comments

Comments
 (0)