Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,35 @@ jobs:
${{ github.workspace }}/target/surefire-reports/TEST-*.xml
if-no-files-found: warn

# ---------------------------------------------------------------------------
# vmlens interleaving analysis — pure-Java, needs no native library or models.
# Staged to a single smoke test for now (see the `vmlens` profile in pom.xml).
# ---------------------------------------------------------------------------
vmlens:
name: Test (vmlens interleavings)
needs: startgate
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-java@v5
with:
distribution: 'temurin'
java-version: ${{ env.JAVA_VERSION }}
cache: maven
- name: Test under vmlens (interleaving analysis)
# Add each new test in the `vmlens` package to this -Dtest list (surefire
# -Dtest matches simple class names, not package globs; the vmlens package is
# excluded from the default suite via the managed surefire <excludes> in pom.xml).
run: >-
mvn --batch-mode --no-transfer-progress -Pvmlens test
-Dtest=VmlensInterleavingSmokeTest,SessionStateInterleavingTest -DfailIfNoTests=false
- uses: actions/upload-artifact@v7
if: always()
with:
name: vmlens-report
path: target/vmlens-report/
if-no-files-found: ignore

test-java-macos-arm64-metal:
name: Java Tests macOS 14 arm64 (Metal)
needs: build-macos-arm64-metal
Expand Down
53 changes: 36 additions & 17 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,21 @@ SPDX-License-Identifier: MIT
<version>${reactor.version}</version>
<scope>test</scope>
</dependency>
<!--
vmlens interleaving-analysis API. Test-scoped and transitive-dependency-free
(its own deps are all test-scope, hence not propagated), so it is safe on the
default test classpath under dependencyConvergence. Needed so the
tests in the `vmlens` package compile in every build; they are only *executed*
under the `vmlens` profile (see that profile + the vmlens CI job). The default
surefire run excludes them (managed surefire <excludes> below) so the vmlens
agent's JDK 9 API class is never loaded outside the agent-driven run.
-->
<dependency>
<groupId>com.vmlens</groupId>
<artifactId>api</artifactId>
<version>${vmlens.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -278,6 +293,19 @@ SPDX-License-Identifier: MIT
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.5.6</version>
<configuration>
<!--
Tests in the `vmlens` package are meaningful only when run
under the vmlens agent (the `vmlens` profile / CI job). Without the
agent their AllInterleavings loop bodies never execute (a vacuous pass
that also prints an "agent not configured" warning), so exclude them
from the ordinary surefire run. The vmlens job re-includes them with
`-Dtest=...`, which overrides this exclude.
-->
<excludes>
<exclude>**/vmlens/*.java</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
Expand Down Expand Up @@ -881,31 +909,22 @@ SPDX-License-Identifier: MIT
</profile>
<profile>
<id>vmlens</id>
<dependencies>
<dependency>
<groupId>com.vmlens</groupId>
<artifactId>api</artifactId>
<version>${vmlens.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>com.vmlens</groupId>
<artifactId>vmlens-maven-plugin</artifactId>
<configuration>
<!--
Lincheck generates its own TestThreadExecution class on the fly.
That bytecode clashes with vmlens's load-time instrumentation
(java.lang.VerifyError). Skip the Lincheck test under vmlens;
the default test job still runs it.
**/*$* is the plugin default - kept to preserve inner-class skip.
Run vmlens interleaving analysis over the whole `vmlens` test
package (smoke test + SessionStateInterleavingTest). The `com.vmlens:api`
test dependency lives in the main <dependencies> block.
Expand <includes> as more concurrency tests are added (the
streambuffer repo runs vmlens over its whole suite).
-->
<excludes>
<exclude>**/*$*</exclude>
<exclude>**/CancellationTokenLincheckTest.java</exclude>
</excludes>
<includes>
<include>**/vmlens/*.java</include>
</includes>
</configuration>
<executions>
<execution>
Expand Down
132 changes: 42 additions & 90 deletions src/main/java/net/ladenthin/llama/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import lombok.ToString;
import net.ladenthin.llama.parameters.InferenceParameters;
import net.ladenthin.llama.value.ChatMessage;
import net.ladenthin.llama.value.ChatTranscript;
import net.ladenthin.llama.value.Pair;
import org.jspecify.annotations.Nullable;

/**
Expand All @@ -30,11 +30,11 @@
* {@link #commitStreamedReply(String)}.
* </p>
*
* <p>{@code toString} is generated by Lombok over the slot id, system message, and
* accumulated turns. The owning {@link LlamaModel} is excluded because its
* {@code toString} would render native state. The {@code paramsCustomizer}
* {@link UnaryOperator} is excluded because lambda {@code toString} is the implementation
* hash, not useful in logs. The intrinsic {@code lock} is excluded as a noise field.
* <p>{@code toString} is generated by Lombok over the slot id and the
* {@link SessionState} (which renders the transcript and streaming flag). The owning
* {@link LlamaModel} is excluded because its {@code toString} would render native
* state. The {@code paramsCustomizer} {@link UnaryOperator} is excluded because lambda
* {@code toString} is the implementation hash, not useful in logs.
* {@code equals}/{@code hashCode} are intentionally NOT generated: a session is a
* mutable lifecycle handle managed by identity.</p>
*/
Expand All @@ -48,23 +48,18 @@ public final class Session implements AutoCloseable {
private final int slotId;

/**
* Append-only transcript with two-phase commit semantics. See the
* {@link net.ladenthin.llama.value.ChatTranscript} class Javadoc for the full invariant statement
* and the {@code ChatTranscriptTest} class for the running-documentation
* tests that pin the contract.
* The lock-guarded streaming-flag + transcript state machine. Extracted to
* {@link SessionState} so its concurrency contract (the two-phase commit and the
* streaming guard) is testable without the native model; see that class and the
* {@code SessionStateInterleavingTest}. This {@code Session} only adds the model
* calls, injected as callbacks that {@link SessionState} runs under its lock.
*/
private final ChatTranscript transcript;
private final SessionState state;

// Lambda UnaryOperator — toString is the implementation hash, not useful in logs.
@ToString.Exclude
private final @Nullable UnaryOperator<InferenceParameters> paramsCustomizer;

// Intrinsic lock used only for synchronisation; rendering its identity adds noise.
@ToString.Exclude
private final Object lock = new Object();

private boolean streamingActive;

/**
* Create a session bound to a specific slot id, with an optional system prompt
* applied to every {@link #send(String)} call.
Expand Down Expand Up @@ -95,7 +90,7 @@ public Session(
@Nullable UnaryOperator<InferenceParameters> paramsCustomizer) {
this.model = model;
this.slotId = slotId;
this.transcript = new ChatTranscript(systemMessage);
this.state = new SessionState(slotId, systemMessage);
this.paramsCustomizer = paramsCustomizer;
}

Expand All @@ -106,22 +101,13 @@ public Session(
* @return the assistant's reply text
*/
public String send(String userMessage) {
synchronized (lock) {
if (streamingActive) {
throw new IllegalStateException("stream in progress on slot " + slotId
+ " (transcript=" + transcript.size() + " turns)"
+ "; call commitStreamedReply(...) before send(...)");
}
// Two-phase commit: build the wire-format with the pending user turn
// outside the transcript via messagesWithPendingUserTurn(...). On
// model success, commit BOTH turns atomically through appendRound(...).
// On model failure, nothing was committed — no rollback logic needed.
// Invariant pinned by ChatTranscriptTest.
InferenceParameters params = buildParamsWithPendingUserTurn(userMessage);
String reply = model.chatCompleteText(params);
transcript.appendRound(userMessage, reply);
return reply;
}
// Two-phase commit lives in SessionState.send(...): it guards against an
// in-progress stream, builds the wire-format with the pending user turn, runs
// the model call below under the lock, and on success commits BOTH turns
// atomically. On model failure nothing is committed — no rollback needed.
return state.send(
userMessage,
(systemMessage, wireMessages) -> model.chatCompleteText(buildParams(systemMessage, wireMessages)));
}

/**
Expand All @@ -134,20 +120,13 @@ public String send(String userMessage) {
* @return a {@link LlamaIterable} that yields assistant reply chunks
*/
public LlamaIterable stream(String userMessage) {
synchronized (lock) {
if (streamingActive) {
throw new IllegalStateException("stream in progress on slot " + slotId
+ " (transcript=" + transcript.size() + " turns)"
+ "; call commitStreamedReply(...) before stream(...)");
}
// Two-phase commit: see send(). The user turn is committed only after
// generateChat successfully returns the iterable; the assistant turn is
// committed separately by commitStreamedReply(...).
LlamaIterable iterable = model.generateChat(buildParamsWithPendingUserTurn(userMessage));
transcript.appendUserTurn(userMessage);
streamingActive = true;
return iterable;
}
// SessionState.beginStream(...) guards against an in-progress stream, runs the
// model call below under the lock, and on success commits the user turn and
// marks streaming active; the assistant turn is committed separately by
// commitStreamedReply(...).
return state.beginStream(
userMessage,
(systemMessage, wireMessages) -> model.generateChat(buildParams(systemMessage, wireMessages)));
}

/**
Expand All @@ -157,15 +136,7 @@ public LlamaIterable stream(String userMessage) {
* @param assistantText the assistant text accumulated from a prior {@link #stream(String)} call
*/
public void commitStreamedReply(String assistantText) {
synchronized (lock) {
if (!streamingActive) {
throw new IllegalStateException("no stream in progress on slot " + slotId
+ " (transcript=" + transcript.size() + " turns)"
+ "; call stream(...) first");
}
transcript.appendAssistantTurn(assistantText);
streamingActive = false;
}
state.commitStreamedReply(assistantText);
}

/**
Expand All @@ -175,14 +146,7 @@ public void commitStreamedReply(String assistantText) {
* @return the JSON response from the native save action
*/
public String save(String filepath) {
synchronized (lock) {
if (streamingActive) {
throw new IllegalStateException("stream in progress on slot " + slotId
+ " (transcript=" + transcript.size() + " turns)"
+ "; call commitStreamedReply(...) before save(...)");
}
return model.saveSlot(slotId, filepath);
}
return state.runWhenNotStreaming("save", () -> model.saveSlot(slotId, filepath));
}

/**
Expand All @@ -192,48 +156,36 @@ public String save(String filepath) {
* @return the JSON response from the native restore action
*/
public String restore(String filepath) {
synchronized (lock) {
if (streamingActive) {
throw new IllegalStateException("stream in progress on slot " + slotId
+ " (transcript=" + transcript.size() + " turns)"
+ "; call commitStreamedReply(...) before restore(...)");
}
return model.restoreSlot(slotId, filepath);
}
return state.runWhenNotStreaming("restore", () -> model.restoreSlot(slotId, filepath));
}

/**
* Transcript accessor.
* @return the accumulated transcript so far, in order, including the system message if any
*/
public List<ChatMessage> getMessages() {
synchronized (lock) {
return transcript.snapshot();
}
return state.snapshot();
}

/** Erase the bound slot's KV cache. Does not modify the in-memory transcript. */
@Override
public void close() {
synchronized (lock) {
model.eraseSlot(slotId);
}
state.runUnderLock(() -> model.eraseSlot(slotId));
}

/**
* Build inference parameters with a pending user turn appended to the existing
* transcript — without mutating the underlying {@link net.ladenthin.llama.value.ChatTranscript}. The
* actual transcript mutation happens AFTER the model call returns successfully,
* either via {@link net.ladenthin.llama.value.ChatTranscript#appendRound(String, String)} (send path)
* or {@link net.ladenthin.llama.value.ChatTranscript#appendUserTurn(String)} (stream path).
* Build inference parameters from the system message and the wire-format messages
* supplied by {@link SessionState} (the committed turns plus the pending user
* turn), applying the optional {@code paramsCustomizer}. The transcript itself is
* never mutated here; {@link SessionState} commits turns only after the model call
* returns successfully.
*
* @param pendingUserMessage the user turn to include in the wire format
* @return inference parameters carrying transcript + pending user turn
* @param systemMessage the system prompt, or {@code null} when none was configured
* @param wireMessages the committed turns plus the pending user turn
* @return inference parameters carrying the system message + wire messages
*/
private InferenceParameters buildParamsWithPendingUserTurn(String pendingUserMessage) {
InferenceParameters params = InferenceParameters.empty()
.withMessages(
transcript.getSystemMessage(), transcript.messagesWithPendingUserTurn(pendingUserMessage));
private InferenceParameters buildParams(@Nullable String systemMessage, List<Pair<String, String>> wireMessages) {
InferenceParameters params = InferenceParameters.empty().withMessages(systemMessage, wireMessages);
return paramsCustomizer == null ? params : paramsCustomizer.apply(params);
}
}
Loading
Loading