Skip to content

Commit eae00f3

Browse files
Coop suspension WIP
1 parent 69e1507 commit eae00f3

13 files changed

Lines changed: 522 additions & 88 deletions

File tree

sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import dev.restate.sdk.core.statemachine.InvocationState;
1919
import dev.restate.sdk.core.statemachine.NotificationValue;
2020
import dev.restate.sdk.core.statemachine.StateMachine;
21+
import dev.restate.sdk.core.statemachine.UnresolvedFuture;
2122
import dev.restate.sdk.endpoint.definition.AsyncResult;
2223
import dev.restate.sdk.endpoint.definition.HandlerType;
2324
import dev.restate.sdk.endpoint.definition.ServiceType;
@@ -28,7 +29,6 @@
2829
import java.util.concurrent.CompletableFuture;
2930
import java.util.concurrent.Executor;
3031
import java.util.function.Consumer;
31-
import java.util.stream.Stream;
3232
import org.apache.logging.log4j.LogManager;
3333
import org.apache.logging.log4j.Logger;
3434
import org.jspecify.annotations.Nullable;
@@ -422,33 +422,39 @@ private void pollAsyncResultInner(AsyncResultInternal<?> asyncResult) {
422422
asyncResult.tryComplete(this.stateMachine);
423423

424424
// Now let's take the unprocessed leaves
425-
List<Integer> uncompletedLeaves =
426-
Stream.concat(asyncResult.uncompletedLeaves(), Stream.of(CANCEL_HANDLE)).toList();
427-
if (uncompletedLeaves.size() == 1) {
425+
List<Integer> uncompletedLeaves = asyncResult.uncompletedLeaves().toList();
426+
if (uncompletedLeaves.isEmpty()) {
428427
// Nothing else to do!
429428
return;
430429
}
431430

431+
// Build the UnresolvedFuture from the leaf handles
432+
UnresolvedFuture future =
433+
uncompletedLeaves.size() == 1
434+
? new UnresolvedFuture.Single(uncompletedLeaves.get(0))
435+
: new UnresolvedFuture.FirstCompleted(
436+
uncompletedLeaves.stream()
437+
.map(h -> (UnresolvedFuture) new UnresolvedFuture.Single(h))
438+
.toList());
439+
432440
// Not ready yet, let's try to do some progress
433-
StateMachine.DoProgressResponse response;
441+
StateMachine.AwaitResponse response;
434442
try {
435-
response = this.stateMachine.doProgress(uncompletedLeaves);
443+
response = this.stateMachine.doAwait(future);
436444
} catch (Throwable e) {
437445
this.failWithoutContextSwitch(e);
438446
asyncResult.publicFuture().completeExceptionally(AbortedExecutionException.INSTANCE);
439447
return;
440448
}
441449

442-
if (response instanceof StateMachine.DoProgressResponse.AnyCompleted) {
450+
if (response instanceof StateMachine.AwaitResponse.AnyCompleted) {
443451
// Let it loop now
444-
} else if (response instanceof StateMachine.DoProgressResponse.ReadFromInput
445-
|| response instanceof StateMachine.DoProgressResponse.WaitingPendingRun) {
452+
} else if (response instanceof StateMachine.AwaitResponse.WaitingExternalProgress wep) {
446453
this.stateMachine.onNextEvent(
447-
() -> this.pollAsyncResultInner(asyncResult),
448-
response instanceof StateMachine.DoProgressResponse.ReadFromInput);
454+
() -> this.pollAsyncResultInner(asyncResult), wep.waitingInput());
449455
return;
450-
} else if (response instanceof StateMachine.DoProgressResponse.ExecuteRun) {
451-
triggerScheduledRun(((StateMachine.DoProgressResponse.ExecuteRun) response).handle());
456+
} else if (response instanceof StateMachine.AwaitResponse.ExecuteRun) {
457+
triggerScheduledRun(((StateMachine.AwaitResponse.ExecuteRun) response).handle());
452458
// Let it loop now
453459
}
454460
}

sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java

Lines changed: 297 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,21 +74,198 @@ public int createHandleMapping(NotificationId notificationId) {
7474
return assignedHandle;
7575
}
7676

77-
public boolean processNextUntilAnyFound(Set<NotificationId> ids) {
78-
while (!toProcess.isEmpty()) {
79-
Map.Entry<NotificationId, NotificationValue> notif = toProcess.removeFirst();
80-
boolean anyFound = ids.contains(notif.getKey());
81-
ready.put(notif.getKey(), notif.getValue());
82-
if (anyFound) {
83-
return true;
77+
/**
78+
* Try to resolve the given future against available notifications.
79+
*
80+
* <p>Operates on a deep-mutable copy of {@code unresolved} so the caller's object is unchanged.
81+
*
82+
* @return {@link ResolveFutureResult.AnyCompleted} if the future can be resolved, or {@link
83+
* ResolveFutureResult.WaitExternalInput} with the remaining (reduced) unresolved future if
84+
* not.
85+
*/
86+
public ResolveFutureResult tryResolveFuture(UnresolvedFuture unresolved) {
87+
// Work on a mutable copy so we can prune completed children in place.
88+
unresolved = deepMutableCopy(unresolved);
89+
90+
while (true) {
91+
TryResolveResult result = tryResolveFutureInternal(unresolved);
92+
93+
if (result == TryResolveResult.SHORT_CIRCUITED || result.handleState().isCompleted()) {
94+
return ResolveFutureResult.ANY_COMPLETED;
95+
}
96+
97+
// Not completed yet — try popping the next queued notification and retry
98+
if (!popNotificationQueue()) {
99+
return new ResolveFutureResult.WaitExternalInput(unresolved);
84100
}
85101
}
86-
return false;
87102
}
88103

104+
/** Create a deep copy of a future tree with all children stored in mutable {@link ArrayList}s. */
105+
private static UnresolvedFuture deepMutableCopy(UnresolvedFuture fut) {
106+
if (fut instanceof UnresolvedFuture.Single) {
107+
return fut;
108+
} else if (fut instanceof UnresolvedFuture.FirstCompleted fc) {
109+
var copy = new ArrayList<UnresolvedFuture>(fc.children().size());
110+
for (var c : fc.children()) copy.add(deepMutableCopy(c));
111+
return new UnresolvedFuture.FirstCompleted(copy);
112+
} else if (fut instanceof UnresolvedFuture.AllCompleted ac) {
113+
var copy = new ArrayList<UnresolvedFuture>(ac.children().size());
114+
for (var c : ac.children()) copy.add(deepMutableCopy(c));
115+
return new UnresolvedFuture.AllCompleted(copy);
116+
} else if (fut instanceof UnresolvedFuture.FirstSucceededOrAllFailed fsaf) {
117+
var copy = new ArrayList<UnresolvedFuture>(fsaf.children().size());
118+
for (var c : fsaf.children()) copy.add(deepMutableCopy(c));
119+
return new UnresolvedFuture.FirstSucceededOrAllFailed(copy);
120+
} else if (fut instanceof UnresolvedFuture.AllSucceededOrFirstFailed asff) {
121+
var copy = new ArrayList<UnresolvedFuture>(asff.children().size());
122+
for (var c : asff.children()) copy.add(deepMutableCopy(c));
123+
return new UnresolvedFuture.AllSucceededOrFirstFailed(copy);
124+
} else if (fut instanceof UnresolvedFuture.Unknown u) {
125+
var copy = new ArrayList<UnresolvedFuture>(u.children().size());
126+
for (var c : u.children()) copy.add(deepMutableCopy(c));
127+
return new UnresolvedFuture.Unknown(copy);
128+
}
129+
throw new IllegalStateException("Unknown UnresolvedFuture type: " + fut);
130+
}
131+
132+
/** Returns false if there's nothing left in toProcess. */
133+
private boolean popNotificationQueue() {
134+
Map.Entry<NotificationId, NotificationValue> notif = toProcess.pollFirst();
135+
if (notif == null) {
136+
return false;
137+
}
138+
ready.put(notif.getKey(), notif.getValue());
139+
return true;
140+
}
141+
142+
/**
143+
* Internal recursive resolution. Returns {@link TryResolveResult#SHORT_CIRCUITED} to signal early
144+
* exit (a combinator completed and wants to propagate up).
145+
*
146+
* <p>This method mutates {@code unresolved} in place when children are removed (e.g. completed
147+
* children are removed from AllCompleted lists).
148+
*/
149+
private TryResolveResult tryResolveFutureInternal(UnresolvedFuture unresolved) {
150+
if (unresolved instanceof UnresolvedFuture.Single s) {
151+
return new TryResolveResult(resolveHandleState(s.handle()));
152+
153+
} else if (unresolved instanceof UnresolvedFuture.FirstCompleted fc) {
154+
return resolveFirstCompleted(fc.children());
155+
156+
} else if (unresolved instanceof UnresolvedFuture.Unknown u) {
157+
return resolveFirstCompleted(u.children());
158+
159+
} else if (unresolved instanceof UnresolvedFuture.AllCompleted ac) {
160+
return resolveAllCompleted(ac.children());
161+
162+
} else if (unresolved instanceof UnresolvedFuture.FirstSucceededOrAllFailed fsaf) {
163+
return resolveFirstSucceededOrAllFailed(fsaf.children());
164+
165+
} else if (unresolved instanceof UnresolvedFuture.AllSucceededOrFirstFailed asff) {
166+
return resolveAllSucceededOrFirstFailed(asff.children());
167+
}
168+
169+
throw new IllegalStateException("Unknown UnresolvedFuture type: " + unresolved);
170+
}
171+
172+
/** FirstCompleted / Unknown: resolve as soon as any child completes (success or failure). */
173+
private TryResolveResult resolveFirstCompleted(List<UnresolvedFuture> children) {
174+
for (UnresolvedFuture child : children) {
175+
TryResolveResult childResult = tryResolveFutureInternal(child);
176+
if (childResult == TryResolveResult.SHORT_CIRCUITED
177+
|| childResult.handleState().isCompleted()) {
178+
children.clear();
179+
return TryResolveResult.SHORT_CIRCUITED;
180+
}
181+
}
182+
return TryResolveResult.PENDING;
183+
}
184+
185+
/** AllCompleted: wait for every child to complete (success or failure). */
186+
private TryResolveResult resolveAllCompleted(List<UnresolvedFuture> children) {
187+
var it = children.listIterator();
188+
while (it.hasNext()) {
189+
UnresolvedFuture child = it.next();
190+
TryResolveResult childResult = tryResolveFutureInternal(child);
191+
if (childResult == TryResolveResult.SHORT_CIRCUITED) {
192+
// A nested combinator short-circuited — propagate immediately
193+
return TryResolveResult.SHORT_CIRCUITED;
194+
} else if (childResult.handleState().isCompleted()) {
195+
it.remove();
196+
}
197+
}
198+
if (children.isEmpty()) {
199+
return new TryResolveResult(HandleState.SUCCEEDED);
200+
}
201+
return TryResolveResult.PENDING;
202+
}
203+
204+
/** FirstSucceededOrAllFailed: first success wins; fail only if all fail. */
205+
private TryResolveResult resolveFirstSucceededOrAllFailed(List<UnresolvedFuture> children) {
206+
var it = children.listIterator();
207+
while (it.hasNext()) {
208+
UnresolvedFuture child = it.next();
209+
TryResolveResult childResult = tryResolveFutureInternal(child);
210+
if (childResult == TryResolveResult.SHORT_CIRCUITED) {
211+
// A nested combinator short-circuited — treat as succeeded, propagate
212+
children.clear();
213+
return TryResolveResult.SHORT_CIRCUITED;
214+
}
215+
HandleState state = childResult.handleState();
216+
if (state == HandleState.SUCCEEDED) {
217+
children.clear();
218+
return TryResolveResult.SHORT_CIRCUITED;
219+
} else if (state == HandleState.FAILED) {
220+
it.remove();
221+
}
222+
}
223+
if (children.isEmpty()) {
224+
return new TryResolveResult(HandleState.FAILED);
225+
}
226+
return TryResolveResult.PENDING;
227+
}
228+
229+
/** AllSucceededOrFirstFailed: all must succeed; first failure short-circuits. */
230+
private TryResolveResult resolveAllSucceededOrFirstFailed(List<UnresolvedFuture> children) {
231+
var it = children.listIterator();
232+
while (it.hasNext()) {
233+
UnresolvedFuture child = it.next();
234+
TryResolveResult childResult = tryResolveFutureInternal(child);
235+
if (childResult == TryResolveResult.SHORT_CIRCUITED) {
236+
// A nested combinator short-circuited — propagate immediately
237+
return TryResolveResult.SHORT_CIRCUITED;
238+
}
239+
HandleState state = childResult.handleState();
240+
if (state == HandleState.FAILED) {
241+
children.clear();
242+
return TryResolveResult.SHORT_CIRCUITED;
243+
} else if (state == HandleState.SUCCEEDED) {
244+
it.remove();
245+
}
246+
}
247+
if (children.isEmpty()) {
248+
return new TryResolveResult(HandleState.SUCCEEDED);
249+
}
250+
return TryResolveResult.PENDING;
251+
}
252+
253+
private HandleState resolveHandleState(int handle) {
254+
NotificationId id = handleMapping.get(handle);
255+
if (id == null) {
256+
return HandleState.PENDING;
257+
}
258+
NotificationValue val = ready.get(id);
259+
if (val == null) {
260+
return HandleState.PENDING;
261+
}
262+
return (val instanceof NotificationValue.Failure) ? HandleState.FAILED : HandleState.SUCCEEDED;
263+
}
264+
265+
/** After {@code take_handle} the mapping is gone, so unknown handles are treated as completed. */
89266
public boolean isHandleCompleted(int handle) {
90267
NotificationId id = handleMapping.get(handle);
91-
return id != null && ready.containsKey(id);
268+
return id == null || ready.containsKey(id);
92269
}
93270

94271
public boolean nonDeterministicFindId(NotificationId id) {
@@ -128,4 +305,115 @@ public Optional<NotificationValue> takeHandle(int handle) {
128305
}
129306
return Optional.empty();
130307
}
308+
309+
public Optional<NotificationValue> copyHandle(int handle) {
310+
NotificationId id = handleMapping.get(handle);
311+
if (id == null) {
312+
return Optional.empty();
313+
}
314+
return Optional.ofNullable(ready.get(id));
315+
}
316+
317+
/**
318+
* Convert an {@link UnresolvedFuture} tree to the wire-format {@link Protocol.Future} message.
319+
* Single children are inlined into the parent's waiting_* fields; all other children become
320+
* nested Future messages.
321+
*/
322+
public Protocol.Future resolveUnresolvedFuture(UnresolvedFuture unresolved) {
323+
var builder = Protocol.Future.newBuilder();
324+
325+
if (unresolved instanceof UnresolvedFuture.Single s) {
326+
builder.setCombinatorType(Protocol.CombinatorType.FIRST_COMPLETED);
327+
pushHandle(builder, s.handle());
328+
return builder.build();
329+
}
330+
331+
List<UnresolvedFuture> children;
332+
if (unresolved instanceof UnresolvedFuture.Unknown u) {
333+
builder.setCombinatorType(Protocol.CombinatorType.COMBINATOR_UNKNOWN);
334+
children = u.children();
335+
} else if (unresolved instanceof UnresolvedFuture.FirstCompleted fc) {
336+
builder.setCombinatorType(Protocol.CombinatorType.FIRST_COMPLETED);
337+
children = fc.children();
338+
} else if (unresolved instanceof UnresolvedFuture.AllCompleted ac) {
339+
builder.setCombinatorType(Protocol.CombinatorType.ALL_COMPLETED);
340+
children = ac.children();
341+
} else if (unresolved instanceof UnresolvedFuture.FirstSucceededOrAllFailed fsaf) {
342+
builder.setCombinatorType(Protocol.CombinatorType.FIRST_SUCCEEDED_OR_ALL_FAILED);
343+
children = fsaf.children();
344+
} else if (unresolved instanceof UnresolvedFuture.AllSucceededOrFirstFailed asff) {
345+
builder.setCombinatorType(Protocol.CombinatorType.ALL_SUCCEEDED_OR_FIRST_FAILED);
346+
children = asff.children();
347+
} else {
348+
throw new IllegalStateException("Unknown UnresolvedFuture type: " + unresolved);
349+
}
350+
351+
for (UnresolvedFuture child : children) {
352+
if (child instanceof UnresolvedFuture.Single s) {
353+
pushHandle(builder, s.handle());
354+
} else {
355+
builder.addNestedFutures(resolveUnresolvedFuture(child));
356+
}
357+
}
358+
359+
return builder.build();
360+
}
361+
362+
private void pushHandle(Protocol.Future.Builder builder, int handle) {
363+
NotificationId id = handleMapping.get(handle);
364+
if (id == null) {
365+
return;
366+
}
367+
if (id instanceof NotificationId.CompletionId cid) {
368+
builder.addWaitingCompletions(cid.id());
369+
} else if (id instanceof NotificationId.SignalId sid) {
370+
builder.addWaitingSignals(sid.id());
371+
} else if (id instanceof NotificationId.SignalName sn) {
372+
builder.addWaitingNamedSignals(sn.name());
373+
}
374+
}
375+
376+
// --- Inner types ---
377+
378+
sealed interface ResolveFutureResult
379+
permits ResolveFutureResult.AnyCompleted, ResolveFutureResult.WaitExternalInput {
380+
381+
ResolveFutureResult ANY_COMPLETED = new AnyCompleted();
382+
383+
record AnyCompleted() implements ResolveFutureResult {}
384+
385+
record WaitExternalInput(UnresolvedFuture remaining) implements ResolveFutureResult {}
386+
}
387+
388+
private enum HandleState {
389+
SUCCEEDED,
390+
FAILED,
391+
PENDING;
392+
393+
boolean isCompleted() {
394+
return this == SUCCEEDED || this == FAILED;
395+
}
396+
}
397+
398+
/**
399+
* Wrapper for the internal resolution result. A sentinel {@link #SHORT_CIRCUITED} value signals
400+
* that a nested combinator completed and the loop should stop.
401+
*/
402+
private static final class TryResolveResult {
403+
static final TryResolveResult SHORT_CIRCUITED = new TryResolveResult(null);
404+
static final TryResolveResult PENDING = new TryResolveResult(HandleState.PENDING);
405+
406+
private final HandleState state;
407+
408+
private TryResolveResult(HandleState state) {
409+
this.state = state;
410+
}
411+
412+
HandleState handleState() {
413+
if (state == null) {
414+
throw new IllegalStateException("SHORT_CIRCUITED has no HandleState");
415+
}
416+
return state;
417+
}
418+
}
131419
}

0 commit comments

Comments
 (0)