
Commit 1c56be5


9 files changed, 310 additions and 18 deletions


.claude/skills/translate-from-shared-core/SKILL.md

Lines changed: 35 additions & 1 deletion
@@ -130,7 +130,8 @@ These are structurally very similar between both codebases:
130130
| `src/vm/transitions/journal.rs` | Methods across `ReplayingState.java` and `ProcessingState.java` |
131131
| `src/vm/transitions/async_results.rs` | Methods in `ReplayingState.java`, `ProcessingState.java`, `AsyncResultsState.java` |
132132
| `src/vm/transitions/terminal.rs` | `hitError()`/`hitSuspended()` methods on `State` interface |
133-
| `src/vm/errors.rs` | `ProtocolException.java` |
133+
| `src/vm/errors.rs` | `ProtocolException.java` (factory methods, not separate error classes) |
134+
| `src/error.rs` (CommandMetadata, NotificationMetadata) | `CommandMetadata.java` (record); notification metadata is built as strings inline |
134135
| `src/service_protocol/` | `MessageDecoder.java`, `MessageEncoder.java`, `MessageType.java`, `ServiceProtocol.java` |
135136

136137
### Command processing patterns
@@ -165,6 +166,39 @@ Java test inputs are built with `ProtoUtils` helpers (`startMessage()`, `inputCm
165166

166167
**Key implication**: When a Rust commit adds a new VM-level test, in Java you typically need to add a handler-level test in the appropriate test suite, not a direct state machine test.
167168

169+
### Test translation details
170+
171+
When translating Rust VM tests to Java:
172+
173+
1. **Identify the right test suite**: Match the Rust test module to the Java abstract test suite:
174+
- `src/tests/failures.rs` (journal_mismatch) → `StateMachineFailuresTestSuite`
175+
- `src/tests/async_result.rs` → `AsyncResultTestSuite`
176+
- `src/tests/run.rs` → `SideEffectTestSuite`
177+
- `src/tests/state.rs` → `StateTestSuite` / `EagerStateTestSuite`
178+
179+
2. **Add abstract method + test definition**: Add the abstract handler method to the suite, then add test definitions using `withInput(...)` and assertion patterns like `assertingOutput(containsOnly(errorMessage(...)))` or `expectingOutput(...)`.
180+
181+
3. **Implement in both Java and Kotlin**: The suite is extended in both:
182+
- `sdk-core/src/test/java/dev/restate/sdk/core/javaapi/<TestName>.java`
183+
- `sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/<TestName>.kt`
184+
185+
4. **Kotlin API differences**:
186+
- Must `import dev.restate.sdk.kotlin.*` for reified extension functions (`runAsync`, `runBlock`, etc.)
187+
- `ctx.runAsync<String>(name) { ... }` (reified, not `ctx.runAsync(name, String.class, ...)`)
188+
- `ctx.awakeable(TestSerdes.STRING)` (needs a serde, not `String::class.java`)
189+
- `ctx.timer(0.milliseconds)` (uses Kotlin Duration)
190+
- Handler factories: `testDefinitionForService<Unit, String?>("Name") { ctx, _: Unit -> ... }`
191+
192+
5. **Cancel signal is always included**: The `HandlerContextImpl` automatically appends `CANCEL_HANDLE` (handle=1, mapping to `SignalId(1)`) to every `doProgress` call. This matches Rust's `CoreVM.do_progress` which appends `cancel_signal_handle`. So in test assertions, the cancel signal notification ID will always be part of the awaited notifications.
193+
194+
6. **ProtoUtils helpers**: Use `startMessage(n)`, `inputCmd()`, `runCmd(completionId, name)`, `suspensionMessage(completionIds...)`, etc. For messages without helpers (e.g., `SleepCommandMessage`, `SleepCompletionNotificationMessage`), build them directly with the protobuf builders.
195+
196+
### Shared utilities
197+
198+
- `Util.awakeableIdStr(invocationId, signalId)` — computes the awakeable ID string from invocation ID and signal ID. Used in both `StateMachineImpl` (for creating awakeables) and `ReplayingState` (for error messages).
199+
- `StateMachineImpl.CANCEL_SIGNAL_ID` — the signal ID for the built-in cancel signal (value: 1). Package-private, available via static import.
200+
- Java 17 target — do NOT use switch pattern matching (`case Type t ->`) in Java source; use `instanceof` chains instead.
201+
168202
## Step 4: Apply the translation
169203

170204
1. **Read the affected Java files first** before making changes
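
Reviewer note on the Java 17 rule in "Shared utilities": the sketch below shows what an `instanceof` chain with pattern variables can look like in place of switch pattern matching. The nested `NotificationId` types and the class name `MismatchMessageSketch` are hypothetical stand-ins for illustration, not the SDK's own classes. The message fragments mirror the ones added to `ProtocolException` in this commit.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class MismatchMessageSketch {
  // Hypothetical stand-ins for the SDK's NotificationId variants.
  interface NotificationId {}

  record CompletionId(int id) implements NotificationId {}

  record SignalId(int id) implements NotificationId {}

  record SignalName(String name) implements NotificationId {}

  // Builds the "awaited notifications" part of the error message.
  // A known description wins; otherwise fall back to a per-type rendering
  // using an instanceof chain, which compiles on a Java 17 target.
  static String describeAwaited(
      List<NotificationId> sortedIds, Map<NotificationId, String> descriptions) {
    StringBuilder sb = new StringBuilder("Notifications awaited on this await point:");
    for (NotificationId id : sortedIds) {
      sb.append("\n - ");
      String description = descriptions.get(id);
      if (description != null) {
        sb.append(description);
      } else if (id instanceof CompletionId completionId) {
        sb.append("completion id ").append(completionId.id());
      } else if (id instanceof SignalId signalId) {
        sb.append("signal [").append(signalId.id()).append("]");
      } else if (id instanceof SignalName signalName) {
        sb.append("signal '").append(signalName.name()).append("'");
      }
    }
    return sb.toString();
  }

  public static void main(String[] args) {
    Map<NotificationId, String> known = new HashMap<>();
    known.put(new SignalId(1), "Cancellation");
    List<NotificationId> ids =
        List.of(new CompletionId(2), new SignalName("x"), new SignalId(1));
    System.out.println(describeAwaited(ids, known));
  }
}
```

Records give value-based `equals`/`hashCode`, which is what lets the description map be keyed by notification ID in this sketch.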

sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java

Lines changed: 27 additions & 0 deletions
@@ -12,6 +12,8 @@
1212
import dev.restate.sdk.common.TerminalException;
1313
import dev.restate.sdk.core.generated.protocol.Protocol;
1414
import dev.restate.sdk.core.statemachine.NotificationId;
15+
import java.util.List;
16+
import java.util.Map;
1517

1618
public class ProtocolException extends RuntimeException {
1719

@@ -133,6 +135,31 @@ public static ProtocolException unauthorized(Throwable e) {
133135
return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e);
134136
}
135137

138+
public static ProtocolException uncompletedDoProgressDuringReplay(
139+
List<NotificationId> sortedNotificationIds,
140+
Map<NotificationId, String> notificationDescriptions) {
141+
var sb = new StringBuilder();
142+
sb.append(
143+
"Found a mismatch between the code paths taken during the previous execution and the paths taken during this execution.\n");
144+
sb.append(
145+
"'Awaiting a future' could not be replayed. This usually means the code was mutated adding an 'await' without registering a new service revision.\n");
146+
sb.append("Notifications awaited on this await point:");
147+
for (var notificationId : sortedNotificationIds) {
148+
sb.append("\n - ");
149+
String description = notificationDescriptions.get(notificationId);
150+
if (description != null) {
151+
sb.append(description);
152+
} else if (notificationId instanceof NotificationId.CompletionId completionId) {
153+
sb.append("completion id ").append(completionId.id());
154+
} else if (notificationId instanceof NotificationId.SignalId signalId) {
155+
sb.append("signal [").append(signalId.id()).append("]");
156+
} else if (notificationId instanceof NotificationId.SignalName signalName) {
157+
sb.append("signal '").append(signalName.name()).append("'");
158+
}
159+
}
160+
return new ProtocolException(sb.toString(), JOURNAL_MISMATCH_CODE);
161+
}
162+
136163
public static ProtocolException unsupportedFeature(
137164
String featureName,
138165
Protocol.ServiceProtocolVersion requiredVersion,

sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java

Lines changed: 90 additions & 4 deletions
@@ -8,6 +8,7 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.sdk.core.statemachine;
1010

11+
import static dev.restate.sdk.core.statemachine.StateMachineImpl.CANCEL_SIGNAL_ID;
1112
import static dev.restate.sdk.core.statemachine.Util.byteStringToSlice;
1213

1314
import com.google.protobuf.ByteString;
@@ -19,13 +20,45 @@
1920
import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse;
2021
import java.util.*;
2122
import java.util.concurrent.CompletableFuture;
23+
import java.util.stream.Collectors;
2224
import org.apache.logging.log4j.LogManager;
2325
import org.apache.logging.log4j.Logger;
2426

2527
final class ReplayingState implements State {
2628

2729
private static final Logger LOG = LogManager.getLogger(ReplayingState.class);
2830

31+
/**
32+
* Comparator for notification IDs in error messages. Orders: completions first (by id), then
33+
* named signals (by name), then signal IDs (by id, with cancel signal last).
34+
*/
35+
private static final Comparator<NotificationId> NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH =
36+
Comparator.<NotificationId>comparingInt(
37+
id -> {
38+
if (id instanceof NotificationId.CompletionId) return 0;
39+
if (id instanceof NotificationId.SignalName) return 1;
40+
return 2;
41+
})
42+
.thenComparing(
43+
(a, b) -> {
44+
if (a instanceof NotificationId.CompletionId ac
45+
&& b instanceof NotificationId.CompletionId bc) {
46+
return Integer.compare(ac.id(), bc.id());
47+
}
48+
if (a instanceof NotificationId.SignalName an
49+
&& b instanceof NotificationId.SignalName bn) {
50+
return an.name().compareTo(bn.name());
51+
}
52+
if (a instanceof NotificationId.SignalId as_
53+
&& b instanceof NotificationId.SignalId bs) {
54+
boolean aIsCancel = as_.id() == CANCEL_SIGNAL_ID;
55+
boolean bIsCancel = bs.id() == CANCEL_SIGNAL_ID;
56+
if (aIsCancel != bIsCancel) return aIsCancel ? 1 : -1;
57+
return Integer.compare(as_.id(), bs.id());
58+
}
59+
return 0;
60+
});
61+
2962
private final Deque<MessageLite> commandsToProcess;
3063
private final AsyncResultsState asyncResultsState;
3164
private final RunState runState;
@@ -68,12 +101,65 @@ public DoProgressResponse doProgress(List<Integer> awaitingOn, StateContext stat
68101
return DoProgressResponse.AnyCompleted.INSTANCE;
69102
}
70103

71-
if (stateContext.isInputClosed()) {
72-
this.hitSuspended(notificationIds, stateContext);
73-
ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE);
104+
// This assertion proves the user mutated the code, adding an await point.
105+
//
106+
// During replay, we transition to processing AFTER replaying all COMMANDS.
107+
// If we reach this point, none of the previous checks succeeded, meaning we don't have
108+
// enough notifications to complete this await point. But if this await cannot be completed
109+
// during replay, then no progress should have been made afterward, meaning there should be
110+
// no more commands to replay. However, we ARE still replaying, which means there ARE commands
111+
// to replay after this await point.
112+
//
113+
// This contradiction proves the code was mutated: an await must have been added after
114+
// the journal was originally created.
115+
116+
// Prepare error metadata to make it easier to debug
117+
Map<NotificationId, String> knownNotificationMetadata = new HashMap<>();
118+
CommandRelationship relatedCommand = null;
119+
120+
// Collect run info
121+
for (int handle : awaitingOn) {
122+
RunState.Run runInfo = runState.getRunInfo(handle);
123+
if (runInfo != null) {
124+
var notifId = asyncResultsState.mustResolveNotificationHandle(handle);
125+
knownNotificationMetadata.put(
126+
notifId,
127+
MessageType.RunCommandMessage.name()
128+
+ " '"
129+
+ runInfo.commandName()
130+
+ "' (command index "
131+
+ runInfo.commandIndex()
132+
+ ")");
133+
relatedCommand =
134+
new CommandRelationship.Specific(
135+
runInfo.commandIndex(), CommandType.RUN, runInfo.commandName());
136+
}
137+
}
138+
139+
// For awakeables and cancellation, add descriptions
140+
for (var notifId : notificationIds) {
141+
if (notifId instanceof NotificationId.SignalId signalId) {
142+
if (signalId.id() == CANCEL_SIGNAL_ID) {
143+
knownNotificationMetadata.put(notifId, "Cancellation");
144+
} else if (signalId.id() > 16) {
145+
knownNotificationMetadata.put(
146+
notifId,
147+
"Awakeable " + Util.awakeableIdStr(stateContext.getStartInfo().id(), signalId.id()));
148+
}
149+
}
74150
}
75151

76-
return DoProgressResponse.ReadFromInput.INSTANCE;
152+
this.hitError(
153+
ProtocolException.uncompletedDoProgressDuringReplay(
154+
notificationIds.stream()
155+
.sorted(NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH)
156+
.collect(Collectors.toList()),
157+
knownNotificationMetadata),
158+
relatedCommand,
159+
null,
160+
stateContext);
161+
ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE);
162+
return null; // unreachable
77163
}
78164

79165
@Override
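
Reviewer note on the ordering used for the journal-mismatch message: the Javadoc on `NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH` says completions come first (by id), then named signals (by name), then signal IDs (by id, cancel signal last). The sketch below reproduces that ordering in a self-contained form so it can be checked in isolation. The nested types and the names `NotificationOrderSketch`, `rank`, and `compareWithinRank` are assumptions made for this example, not SDK API.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class NotificationOrderSketch {
  // Hypothetical stand-ins for the SDK's NotificationId variants.
  interface NotificationId {}

  record CompletionId(int id) implements NotificationId {}

  record SignalId(int id) implements NotificationId {}

  record SignalName(String name) implements NotificationId {}

  // Same value as StateMachineImpl.CANCEL_SIGNAL_ID in this commit.
  static final int CANCEL_SIGNAL_ID = 1;

  // Coarse ordering: completions, then named signals, then signal IDs.
  static int rank(NotificationId id) {
    if (id instanceof CompletionId) return 0;
    if (id instanceof SignalName) return 1;
    return 2;
  }

  // Fine ordering between two IDs of the same kind.
  static int compareWithinRank(NotificationId a, NotificationId b) {
    if (a instanceof CompletionId ca && b instanceof CompletionId cb) {
      return Integer.compare(ca.id(), cb.id());
    }
    if (a instanceof SignalName na && b instanceof SignalName nb) {
      return na.name().compareTo(nb.name());
    }
    if (a instanceof SignalId sa && b instanceof SignalId sb) {
      boolean aIsCancel = sa.id() == CANCEL_SIGNAL_ID;
      boolean bIsCancel = sb.id() == CANCEL_SIGNAL_ID;
      // The cancel signal sorts after every other signal ID.
      if (aIsCancel != bIsCancel) return aIsCancel ? 1 : -1;
      return Integer.compare(sa.id(), sb.id());
    }
    return 0;
  }

  static final Comparator<NotificationId> ORDER =
      Comparator.<NotificationId>comparingInt(NotificationOrderSketch::rank)
          .thenComparing((a, b) -> compareWithinRank(a, b));

  // Returns a sorted copy and leaves the input untouched.
  static List<NotificationId> sorted(List<NotificationId> ids) {
    List<NotificationId> copy = new ArrayList<>(ids);
    copy.sort(ORDER);
    return copy;
  }

  public static void main(String[] args) {
    List<NotificationId> ids =
        List.of(new SignalId(1), new SignalId(17), new SignalName("b"), new CompletionId(3));
    System.out.println(sorted(ids));
  }
}
```

Putting the cancel signal last keeps the notifications the user actually awaited at the top of the error message, since the cancel handle is appended to every `doProgress` call.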

sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ public void insertRunToExecute(int handle, int commandIndex, String commandName)
3232
return null;
3333
}
3434

35+
public @Nullable Run getRunInfo(int handle) {
36+
return runs.get(handle);
37+
}
38+
3539
public boolean anyExecuting(Collection<Integer> anyHandle) {
3640
return anyHandle.stream()
3741
.anyMatch(h -> runs.containsKey(h) && runs.get(h).state == RunStateInner.Executing);

sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java

Lines changed: 2 additions & 12 deletions
@@ -19,7 +19,6 @@
1919
import dev.restate.sdk.core.ProtocolException;
2020
import dev.restate.sdk.core.generated.protocol.Protocol;
2121
import dev.restate.sdk.endpoint.HeadersAccessor;
22-
import java.nio.ByteBuffer;
2322
import java.time.Duration;
2423
import java.time.Instant;
2524
import java.util.*;
@@ -34,8 +33,7 @@
3433
class StateMachineImpl implements StateMachine {
3534

3635
private static final Logger LOG = LogManager.getLogger(StateMachineImpl.class);
37-
private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1";
38-
private static final int CANCEL_SIGNAL_ID = 1;
36+
static final int CANCEL_SIGNAL_ID = 1;
3937

4038
// Callbacks
4139
private final CompletableFuture<Void> waitForReadyFuture = new CompletableFuture<>();
@@ -385,15 +383,7 @@ public Awakeable awakeable() {
385383
.createSignalHandle(new NotificationId.SignalId(signalId), this.stateContext);
386384

387385
// Encode awakeable id
388-
String awakeableId =
389-
AWAKEABLE_IDENTIFIER_PREFIX
390-
+ Base64.getUrlEncoder()
391-
.encodeToString(
392-
this.stateContext
393-
.getStartInfo()
394-
.id()
395-
.concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip()))
396-
.toByteArray());
386+
String awakeableId = Util.awakeableIdStr(this.stateContext.getStartInfo().id(), signalId);
397387

398388
return new Awakeable(awakeableId, signalHandle);
399389
}

sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java

Lines changed: 12 additions & 0 deletions
@@ -16,6 +16,7 @@
1616
import dev.restate.sdk.core.generated.protocol.Protocol;
1717
import java.nio.ByteBuffer;
1818
import java.time.Duration;
19+
import java.util.Base64;
1920
import java.util.Map;
2021
import java.util.Objects;
2122
import java.util.stream.Collectors;
@@ -78,6 +79,17 @@ static Duration durationMin(Duration a, Duration b) {
7879
return (a.compareTo(b) <= 0) ? a : b;
7980
}
8081

82+
private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1";
83+
84+
static String awakeableIdStr(ByteString invocationId, int signalId) {
85+
return AWAKEABLE_IDENTIFIER_PREFIX
86+
+ Base64.getUrlEncoder()
87+
.encodeToString(
88+
invocationId
89+
.concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip()))
90+
.toByteArray());
91+
}
92+
8193
/**
8294
* Returns a string representation of a command message.
8395
*
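
Reviewer note on the extracted `awakeableIdStr` helper: the encoding is the `sign_1` prefix followed by the URL-safe Base64 of the invocation ID bytes concatenated with the signal ID as a 4-byte big-endian integer. The sketch below shows the same layout using a plain `byte[]` instead of protobuf's `ByteString`, which is an assumption made so the example runs without the protobuf dependency. The class name `AwakeableIdSketch` is hypothetical.

```java
import java.nio.ByteBuffer;
import java.util.Base64;

public class AwakeableIdSketch {
  private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1";

  // invocationId is a byte[] here (assumption); the SDK helper takes a ByteString.
  static String awakeableIdStr(byte[] invocationId, int signalId) {
    byte[] payload =
        ByteBuffer.allocate(invocationId.length + 4)
            .put(invocationId)
            .putInt(signalId) // ByteBuffer writes big-endian by default
            .array();
    // getUrlEncoder() uses the URL-safe alphabet and keeps '=' padding.
    return AWAKEABLE_IDENTIFIER_PREFIX + Base64.getUrlEncoder().encodeToString(payload);
  }

  public static void main(String[] args) {
    System.out.println(awakeableIdStr(new byte[] {1, 2, 3}, 17));
  }
}
```

Keeping the encoding in one helper means the ID created in `StateMachineImpl.awakeable()` and the ID printed in the replay error message cannot drift apart.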

sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java

Lines changed: 74 additions & 1 deletion
@@ -12,6 +12,7 @@
1212
import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder;
1313
import static dev.restate.sdk.core.statemachine.ProtoUtils.*;
1414
import static org.assertj.core.api.Assertions.assertThat;
15+
import static org.assertj.core.api.InstanceOfAssertFactories.STRING;
1516

1617
import dev.restate.sdk.core.generated.protocol.Protocol;
1718
import dev.restate.serde.Serde;
@@ -25,6 +26,12 @@ public abstract class StateMachineFailuresTestSuite implements TestDefinitions.T
2526

2627
protected abstract TestInvocationBuilder sideEffectFailure(Serde<Integer> serde);
2728

29+
protected abstract TestInvocationBuilder awaitRunAfterProgressWasMade();
30+
31+
protected abstract TestInvocationBuilder awaitSleepAfterProgressWasMade();
32+
33+
protected abstract TestInvocationBuilder awaitAwakeableAfterProgressWasMade();
34+
2835
private static final Serde<Integer> FAILING_SERIALIZATION_INTEGER_TYPE_TAG =
2936
Serde.using(
3037
i -> {
@@ -91,6 +98,72 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
9198
.assertingOutput(
9299
containsOnly(
93100
errorDescriptionStartingWith(IllegalStateException.class.getCanonicalName())))
94-
.named("Serde deserialization error"));
101+
.named("Serde deserialization error"),
102+
// --- Uncompleted doProgress during replay (bad await) tests
103+
this.awaitRunAfterProgressWasMade()
104+
.withInput(
105+
startMessage(4),
106+
inputCmd(),
107+
runCmd(1, "my-side-effect"),
108+
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(),
109+
Protocol.SleepCompletionNotificationMessage.newBuilder()
110+
.setCompletionId(2)
111+
.setVoid(Protocol.Void.getDefaultInstance())
112+
.build())
113+
.assertingOutput(
114+
containsOnly(
115+
errorMessage(
116+
errorMessage ->
117+
assertThat(errorMessage)
118+
.returns(
119+
ProtocolException.JOURNAL_MISMATCH_CODE,
120+
Protocol.ErrorMessage::getCode)
121+
.extracting(Protocol.ErrorMessage::getMessage, STRING)
122+
.contains("could not be replayed")
123+
.contains("await"))))
124+
.named("Add await on run after progress was made"),
125+
this.awaitSleepAfterProgressWasMade()
126+
.withInput(
127+
startMessage(4),
128+
inputCmd(),
129+
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(1).build(),
130+
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(),
131+
Protocol.SleepCompletionNotificationMessage.newBuilder()
132+
.setCompletionId(2)
133+
.setVoid(Protocol.Void.getDefaultInstance())
134+
.build())
135+
.assertingOutput(
136+
containsOnly(
137+
errorMessage(
138+
errorMessage ->
139+
assertThat(errorMessage)
140+
.returns(
141+
ProtocolException.JOURNAL_MISMATCH_CODE,
142+
Protocol.ErrorMessage::getCode)
143+
.extracting(Protocol.ErrorMessage::getMessage, STRING)
144+
.contains("could not be replayed")
145+
.contains("await"))))
146+
.named("Add await on sleep after progress was made"),
147+
this.awaitAwakeableAfterProgressWasMade()
148+
.withInput(
149+
startMessage(3),
150+
inputCmd(),
151+
Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(),
152+
Protocol.SleepCompletionNotificationMessage.newBuilder()
153+
.setCompletionId(2)
154+
.setVoid(Protocol.Void.getDefaultInstance())
155+
.build())
156+
.assertingOutput(
157+
containsOnly(
158+
errorMessage(
159+
errorMessage ->
160+
assertThat(errorMessage)
161+
.returns(
162+
ProtocolException.JOURNAL_MISMATCH_CODE,
163+
Protocol.ErrorMessage::getCode)
164+
.extracting(Protocol.ErrorMessage::getMessage, STRING)
165+
.contains("could not be replayed")
166+
.contains("await"))))
167+
.named("Add await on awakeable after progress was made"));
95168
}
96169
}
