Skip to content

Commit 6e4f56d

Browse files
committed
fix concurrency operation
1 parent 7f2d2fc commit 6e4f56d

7 files changed

Lines changed: 71 additions & 33 deletions

File tree

sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,8 +1098,7 @@ void testMapWithMinSuccessful_replayLargePayloadResultConsistency(NestingType ne
10981098
if (initialResult.get() == null) {
10991099
initialResult.set(result);
11001100
} else {
1101-
// todo: this test would fail because 5th branch is skipped when replay
1102-
// assertEquals(initialResult.get(), result);
1101+
assertEquals(initialResult.get(), result);
11031102
}
11041103

11051104
return "done";

sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import java.util.List;
1010
import java.util.concurrent.atomic.AtomicInteger;
1111
import java.util.concurrent.atomic.AtomicReference;
12-
1312
import org.junit.jupiter.params.ParameterizedTest;
1413
import org.junit.jupiter.params.provider.CsvSource;
1514
import software.amazon.lambda.durable.config.CompletionConfig;
@@ -121,7 +120,7 @@ void testParallelPartialFailure_failedBranchDoesNotPreventOthers(NestingType nes
121120
void testParallelAllBranchesFail(NestingType nestingType, int events) {
122121
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
123122
var config = ParallelConfig.builder().nestingType(nestingType).build();
124-
var parallel = context.parallel("all-fail", config);
123+
ParallelDurableFuture parallel = context.parallel("all-fail", config);
125124

126125
try (parallel) {
127126
parallel.branch("branch-x", String.class, ctx -> {
@@ -158,7 +157,7 @@ void testParallelWithMaxConcurrency1_sequentialExecution(NestingType nestingType
158157
.nestingType(nestingType)
159158
.build();
160159
var futures = new ArrayList<DurableFuture<String>>();
161-
var parallel = context.parallel("sequential-parallel", config);
160+
ParallelDurableFuture parallel = context.parallel("sequential-parallel", config);
162161

163162
try (parallel) {
164163
for (var item : List.of("a", "b", "c", "d")) {
@@ -197,7 +196,7 @@ void testParallelWithMaxConcurrency2_limitedConcurrency(NestingType nestingType,
197196
.nestingType(nestingType)
198197
.build();
199198
var futures = new ArrayList<DurableFuture<String>>();
200-
var parallel = context.parallel("limited-parallel", config);
199+
ParallelDurableFuture parallel = context.parallel("limited-parallel", config);
201200

202201
try (parallel) {
203202
for (var item : List.of("a", "b", "c", "d", "e")) {
@@ -232,7 +231,7 @@ void testParallelReplayAfterInterruption_cachedResultsUsed(NestingType nestingTy
232231
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
233232
var config = ParallelConfig.builder().nestingType(nestingType).build();
234233
var futures = new ArrayList<DurableFuture<String>>();
235-
var parallel = context.parallel("replay-parallel", config);
234+
ParallelDurableFuture parallel = context.parallel("replay-parallel", config);
236235

237236
try (parallel) {
238237
for (var item : List.of("a", "b", "c")) {
@@ -534,7 +533,7 @@ void testParallelUnlimitedConcurrencyWithToleratedFailureCount(NestingType nesti
534533
void testParallelBranchesReturnDifferentTypes(NestingType nestingType, int events) {
535534
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
536535
var config = ParallelConfig.builder().nestingType(nestingType).build();
537-
var parallel = context.parallel("mixed-types", config);
536+
ParallelDurableFuture parallel = context.parallel("mixed-types", config);
538537

539538
DurableFuture<String> strFuture;
540539
DurableFuture<Integer> intFuture;
@@ -658,7 +657,7 @@ void testParallel50BranchesWithWaitForCallback_maxConcurrency5(NestingType nesti
658657
.maxConcurrency(5)
659658
.nestingType(nestingType)
660659
.build();
661-
var parallel = context.parallel("50-callbacks-limited", config);
660+
ParallelDurableFuture parallel = context.parallel("50-callbacks-limited", config);
662661

663662
try (parallel) {
664663
for (int i = 0; i < branchCount; i++) {
@@ -708,7 +707,7 @@ void testParallel50BranchesWithWaitForCallback_partialFailure(NestingType nestin
708707

709708
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
710709
var config = ParallelConfig.builder().nestingType(nestingType).build();
711-
var parallel = context.parallel("50-callbacks-partial-fail", config);
710+
ParallelDurableFuture parallel = context.parallel("50-callbacks-partial-fail", config);
712711

713712
try (parallel) {
714713
for (int i = 0; i < branchCount; i++) {
@@ -862,7 +861,7 @@ void testParallel50BranchesWithWaitForCondition_someExceedMaxAttempts(NestingTyp
862861

863862
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
864863
var config = ParallelConfig.builder().nestingType(nestingType).build();
865-
var parallel = context.parallel("50-conditions-some-fail", config);
864+
ParallelDurableFuture parallel = context.parallel("50-conditions-some-fail", config);
866865

867866
try (parallel) {
868867
for (int i = 0; i < branchCount; i++) {
@@ -913,7 +912,7 @@ void testParallel50BranchesWithWaitForCondition_replay(NestingType nestingType,
913912

914913
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
915914
var config = ParallelConfig.builder().nestingType(nestingType).build();
916-
var parallel = context.parallel("50-conditions-replay", config);
915+
ParallelDurableFuture parallel = context.parallel("50-conditions-replay", config);
917916

918917
try (parallel) {
919918
for (int i = 0; i < branchCount; i++) {
@@ -963,7 +962,7 @@ void testParallel50BranchesMixed_callbackAndCondition(NestingType nestingType, i
963962

964963
var runner = LocalDurableTestRunner.create(String.class, (input, context) -> {
965964
var config = ParallelConfig.builder().nestingType(nestingType).build();
966-
var parallel = context.parallel("50-mixed", config);
965+
ParallelDurableFuture parallel = context.parallel("50-mixed", config);
967966

968967
try (parallel) {
969968
for (int i = 0; i < branchCount; i++) {
@@ -1128,8 +1127,7 @@ void testParallelWithMinSuccessful_earlyTermination_consistentResult(NestingType
11281127
if (initialResult.get() == null) {
11291128
initialResult.set(result);
11301129
} else {
1131-
//todo: fix this
1132-
// assertEquals(initialResult.get(), result);
1130+
assertEquals(initialResult.get(), result);
11331131
}
11341132
assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionStatus());
11351133
assertTrue(result.completionStatus().isSucceeded());

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ public abstract class BaseDurableOperation {
5151
protected final CompletableFuture<BaseDurableOperation> completionFuture;
5252
protected final BaseDurableOperation parentOperation;
5353
protected final boolean isVirtual;
54+
protected final AtomicBoolean replayCompletedOperation = new AtomicBoolean(false);
5455
private final DurableContextImpl durableContext;
5556
private final AtomicReference<CompletableFuture<Void>> runningUserHandler = new AtomicReference<>(null);
56-
private final AtomicBoolean replayCompletedOperation = new AtomicBoolean(false);
5757

5858
protected BaseDurableOperation(
5959
OperationIdentifier operationIdentifier,

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
*/
5353
public abstract class ConcurrencyOperation<T> extends SerializableDurableOperation<T> {
5454

55+
protected record ExpectedCompletionStatus(int completed, ConcurrencyCompletionStatus completionStatus) {}
56+
5557
private static final Logger logger = LoggerFactory.getLogger(ConcurrencyOperation.class);
5658

5759
private final int maxConcurrency;
@@ -129,9 +131,9 @@ protected <R> ChildContextOperation<R> createItem(
129131
// ========== Concurrency control ==========
130132

131133
/**
132-
* Creates and enqueues an item without starting execution. Use {@link #executeItems()} to begin execution after all
133-
* items have been enqueued. This prevents early termination from blocking item creation when all items are known
134-
* upfront (e.g., map operations).
134+
* Creates and enqueues an item without starting execution. Use {@link #executeItems(ExpectedCompletionStatus)} to
135+
* begin execution after all items have been enqueued. This prevents early termination from blocking item creation
136+
* when all items are known upfront (e.g., map operations).
135137
*/
136138
protected <R> ChildContextOperation<R> enqueueItem(
137139
String name,
@@ -160,6 +162,11 @@ private void notifyConsumerThread() {
160162

161163
/** Starts execution of all enqueued items. */
162164
protected void executeItems() {
165+
executeItems(null);
166+
}
167+
168+
/** Starts execution of all enqueued items until the expectedCompletionStatus is met. */
169+
protected void executeItems(ExpectedCompletionStatus expectedCompletionStatus) {
163170
// variables accessed only by the consumer thread. Put them here to avoid accidentally used by other threads
164171
Set<BaseDurableOperation> runningChildren = new HashSet<>();
165172
AtomicInteger succeededCount = new AtomicInteger(0);
@@ -182,7 +189,8 @@ protected void executeItems() {
182189
if (isOperationCompleted()) {
183190
return;
184191
}
185-
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
192+
var completionStatus =
193+
canComplete(succeededCount, failedCount, runningChildren, expectedCompletionStatus);
186194
if (completionStatus != null) {
187195
handleCompletion(completionStatus);
188196
return;
@@ -198,7 +206,8 @@ protected void executeItems() {
198206

199207
// If consumerThreadListener has been completed when processing above, waitForChildCompletion will
200208
// immediately return null and repeat the above again
201-
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
209+
var child = waitForChildCompletion(
210+
succeededCount, failedCount, runningChildren, expectedCompletionStatus);
202211

203212
// child may be null if the consumer thread is woken up due to new items added or completion
204213
// condition
@@ -235,7 +244,10 @@ private void handleException(Throwable ex) {
235244
}
236245

237246
private BaseDurableOperation waitForChildCompletion(
238-
AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
247+
AtomicInteger succeededCount,
248+
AtomicInteger failedCount,
249+
Set<BaseDurableOperation> runningChildren,
250+
ExpectedCompletionStatus expectedCompletionStatus) {
239251
var threadContext = getCurrentThreadContext();
240252
CompletableFuture<Object> future;
241253

@@ -244,7 +256,7 @@ private BaseDurableOperation waitForChildCompletion(
244256
if (isOperationCompleted()) {
245257
return null;
246258
}
247-
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
259+
var completionStatus = canComplete(succeededCount, failedCount, runningChildren, expectedCompletionStatus);
248260
if (completionStatus != null) {
249261
return null;
250262
}
@@ -305,10 +317,22 @@ private void onItemComplete(
305317
* @return the completion status if the operation is complete, or null if it should continue
306318
*/
307319
private ConcurrencyCompletionStatus canComplete(
308-
AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
320+
AtomicInteger succeededCount,
321+
AtomicInteger failedCount,
322+
Set<BaseDurableOperation> runningChildren,
323+
ExpectedCompletionStatus expectedCompletionStatus) {
309324
int succeeded = succeededCount.get();
310325
int failed = failedCount.get();
311326

327+
if (expectedCompletionStatus != null) {
328+
if (succeeded + failed >= expectedCompletionStatus.completed) {
329+
return expectedCompletionStatus.completionStatus;
330+
}
331+
332+
// if expected completion status is not null, we always complete all the children previously completed
333+
return null;
334+
}
335+
312336
// If we've met the minimum successful count, we're done
313337
if (minSuccessful != null && succeeded >= minSuccessful) {
314338
return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED;

sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ protected void replay(Operation existing) {
152152
}
153153
if (Boolean.TRUE.equals(existing.contextDetails().replayChildren())) {
154154
// Large result: re-execute children to reconstruct MapResult
155-
executeItems();
155+
var expected = new ExpectedCompletionStatus(
156+
deserializedResult.succeeded().size()
157+
+ deserializedResult.failed().size(),
158+
deserializedResult.completionReason());
159+
executeItems(expected);
156160
} else {
157161
// Small result: MapResult is in the payload, skip child replay
158162
cachedResult = deserializedResult;
@@ -173,7 +177,7 @@ protected void replay(Operation existing) {
173177

174178
@Override
175179
protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletionStatus) {
176-
this.cachedResult = constructMapResult(concurrencyCompletionStatus, false);
180+
this.cachedResult = constructMapResult(concurrencyCompletionStatus);
177181
var serialized = serializeResult(cachedResult);
178182
var serializedBytes = serialized.getBytes(StandardCharsets.UTF_8);
179183

@@ -184,7 +188,7 @@ protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletio
184188
.payload(serialized));
185189
} else {
186190
// Large result: checkpoint with stripped payload + replayChildren flag
187-
var strippedResult = serializeResult(constructMapResult(concurrencyCompletionStatus, true));
191+
var strippedResult = serializeResult(stripMapResult(cachedResult));
188192
sendOperationUpdate(OperationUpdate.builder()
189193
.action(OperationAction.SUCCEED)
190194
.subType(getSubType().getValue())
@@ -194,9 +198,16 @@ protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletio
194198
}
195199
}
196200

201+
private MapResult<O> stripMapResult(MapResult<O> result) {
202+
return new MapResult<>(
203+
result.items().stream()
204+
.map(item -> new MapResult.MapResultItem<O>(item.status(), null, null))
205+
.toList(),
206+
result.completionReason());
207+
}
208+
197209
@SuppressWarnings("unchecked")
198-
private MapResult<O> constructMapResult(
199-
ConcurrencyCompletionStatus concurrencyCompletionStatus, boolean stripResult) {
210+
private MapResult<O> constructMapResult(ConcurrencyCompletionStatus concurrencyCompletionStatus) {
200211
var children = getBranches();
201212
var resultItems = new ArrayList<MapResult.MapResultItem<O>>(Collections.nCopies(items.size(), null));
202213

@@ -206,7 +217,7 @@ private MapResult<O> constructMapResult(
206217
resultItems.set(i, MapResult.MapResultItem.skipped());
207218
} else {
208219
try {
209-
resultItems.set(i, MapResult.MapResultItem.succeeded(stripResult ? null : branch.get()));
220+
resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get()));
210221
} catch (Throwable exception) {
211222
Throwable throwable = ExceptionHelper.unwrapCompletableFuture(exception);
212223
if (throwable instanceof SuspendExecutionException suspendExecutionException) {
@@ -218,8 +229,7 @@ private MapResult<O> constructMapResult(
218229
// terminate the execution and throw the exception if it's not recoverable
219230
throw terminateExecution(unrecoverableDurableExecutionException);
220231
}
221-
resultItems.set(
222-
i, MapResult.MapResultItem.failed(stripResult ? null : MapResult.MapError.of(throwable)));
232+
resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(throwable)));
223233
}
224234
}
225235
}

sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ protected void replay(Operation existing) {
126126
partialResult = existing.contextDetails() != null
127127
? deserializeResult(existing.contextDetails().result())
128128
: null;
129+
if (partialResult != null) {
130+
var expected = new ExpectedCompletionStatus(
131+
partialResult.succeeded() + partialResult.failed(), partialResult.completionStatus());
132+
executeItems(expected);
133+
return;
134+
}
129135
}
130136
executeItems();
131137
}

sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ void minSuccessful_notExecuteSkippedBranchWhenReplay() {
269269
.subType(OperationSubType.PARALLEL.getValue())
270270
.status(OperationStatus.SUCCEEDED)
271271
.contextDetails(ContextDetails.builder()
272-
.result("{\"statuses\":[\"SKIPPED\", \"SUCCEEDED\"]}")
272+
.result(
273+
"{\"succeeded\": 1, \"completionStatus\": \"MIN_SUCCESSFUL_REACHED\", \"statuses\":[\"SKIPPED\", \"SUCCEEDED\"]}")
273274
.build())
274275
.build());
275276
when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2))

0 commit comments

Comments
 (0)