Skip to content

Commit e4f531c

Browse files
committed
Add Session multi-turn helper + ChatMessage value type (§2.6)
Session is a thin wrapper over LlamaModel: it owns a slot id, an accumulating user/assistant transcript, and an optional system message and parameter customizer. send(userMessage) appends both sides of the turn and runs chatCompleteText with the full history. stream(userMessage) returns a LlamaIterable for streamed replies; commitStreamedReply records the assistant turn once the caller has accumulated the text. save/restore delegate to existing LlamaModel.saveSlot/restoreSlot. close() erases the slot's KV cache. Single-threaded use only in this pass — per-session locking is the M-effort follow-up. ChatMessage is the minimal value type for the transcript; will be reused by ChatResponse when §2.2 lands. https://claude.ai/code/session_01R4ZrEy3ptJDLuUgUKuM4Gy
1 parent 1e673a9 commit e4f531c

4 files changed

Lines changed: 214 additions & 0 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama;
6+
7+
/**
8+
* A single message in a chat conversation: a role ({@code "user"}, {@code "assistant"},
9+
* or {@code "system"}) and its textual content. Used by {@link Session} to accumulate
10+
* conversation turns.
11+
*/
12+
public final class ChatMessage {
13+
14+
private final String role;
15+
private final String content;
16+
17+
public ChatMessage(String role, String content) {
18+
this.role = role;
19+
this.content = content;
20+
}
21+
22+
public String getRole() {
23+
return role;
24+
}
25+
26+
public String getContent() {
27+
return content;
28+
}
29+
30+
@Override
31+
public String toString() {
32+
return role + ": " + content;
33+
}
34+
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama;
6+
7+
import java.util.ArrayList;
8+
import java.util.Collections;
9+
import java.util.List;
10+
import java.util.function.Consumer;
11+
12+
/**
13+
* Thin multi-turn conversation wrapper over a {@link LlamaModel} slot. Maintains an
14+
* accumulating list of {@link ChatMessage} turns and forwards each {@link #send(String)}
15+
* to the underlying chat-completion API with the full transcript so far. KV-cache state
16+
* for the bound slot can be persisted via {@link #save(String)} and restored with
17+
* {@link #restore(String)}, which delegate to {@link LlamaModel#saveSlot(int, String)}
18+
* and {@link LlamaModel#restoreSlot(int, String)}.
19+
* <p>
20+
* This wrapper is intentionally not thread-safe; callers must serialize access to a
21+
* single {@code Session} instance. Concurrency support is a follow-up (M-effort) item.
22+
* </p>
23+
*/
24+
public final class Session implements AutoCloseable {
25+
26+
private final LlamaModel model;
27+
private final int slotId;
28+
private final String systemMessage;
29+
private final List<Pair<String, String>> turns = new ArrayList<Pair<String, String>>();
30+
private final Consumer<InferenceParameters> paramsCustomizer;
31+
32+
/**
33+
* Create a session bound to a specific slot id, with an optional system prompt
34+
* applied to every {@link #send(String)} call.
35+
*
36+
* @param model the underlying model
37+
* @param slotId the slot id used by {@link #save(String)} / {@link #restore(String)}
38+
* @param systemMessage optional system prompt (may be {@code null} or empty)
39+
*/
40+
public Session(LlamaModel model, int slotId, String systemMessage) {
41+
this(model, slotId, systemMessage, null);
42+
}
43+
44+
/**
45+
* Create a session with a customizer that gets to mutate the
46+
* {@link InferenceParameters} for every call (e.g. set temperature, n_predict).
47+
*
48+
* @param model the underlying model
49+
* @param slotId the slot id
50+
* @param systemMessage optional system prompt
51+
* @param paramsCustomizer applied to each request's parameters; may be {@code null}
52+
*/
53+
public Session(LlamaModel model, int slotId, String systemMessage,
54+
Consumer<InferenceParameters> paramsCustomizer) {
55+
this.model = model;
56+
this.slotId = slotId;
57+
this.systemMessage = systemMessage;
58+
this.paramsCustomizer = paramsCustomizer;
59+
}
60+
61+
/** Send a user message and return the assistant's text reply, appending both to the transcript. */
62+
public String send(String userMessage) {
63+
turns.add(new Pair<String, String>("user", userMessage));
64+
InferenceParameters params = buildParams();
65+
String reply = model.chatCompleteText(params);
66+
turns.add(new Pair<String, String>("assistant", reply));
67+
return reply;
68+
}
69+
70+
/**
71+
* Streaming variant of {@link #send(String)}. The returned iterable yields chunks of
72+
* the assistant reply; consume it fully (or via try-with-resources) before calling
73+
* {@link #send(String)} again, because the assistant turn is only appended to the
74+
* transcript when the caller invokes {@link #commitStreamedReply(String)}.
75+
*/
76+
public LlamaIterable stream(String userMessage) {
77+
turns.add(new Pair<String, String>("user", userMessage));
78+
return model.generateChat(buildParams());
79+
}
80+
81+
/**
82+
* Record an assistant reply that was produced by a previous {@link #stream(String)}
83+
* call. Called by the caller after it has accumulated the streamed text.
84+
*/
85+
public void commitStreamedReply(String assistantText) {
86+
turns.add(new Pair<String, String>("assistant", assistantText));
87+
}
88+
89+
/** Save this session's slot KV cache to {@code filepath}. */
90+
public String save(String filepath) {
91+
return model.saveSlot(slotId, filepath);
92+
}
93+
94+
/** Restore this session's slot KV cache from {@code filepath}. */
95+
public String restore(String filepath) {
96+
return model.restoreSlot(slotId, filepath);
97+
}
98+
99+
/** The accumulated turns so far, in order. */
100+
public List<ChatMessage> getMessages() {
101+
List<ChatMessage> out = new ArrayList<ChatMessage>(turns.size() + 1);
102+
if (systemMessage != null && !systemMessage.isEmpty()) {
103+
out.add(new ChatMessage("system", systemMessage));
104+
}
105+
for (Pair<String, String> p : turns) {
106+
out.add(new ChatMessage(p.getKey(), p.getValue()));
107+
}
108+
return Collections.unmodifiableList(out);
109+
}
110+
111+
/** Erase the bound slot's KV cache. Does not modify the in-memory transcript. */
112+
@Override
113+
public void close() {
114+
model.eraseSlot(slotId);
115+
}
116+
117+
private InferenceParameters buildParams() {
118+
InferenceParameters params = new InferenceParameters("")
119+
.setMessages(systemMessage, new ArrayList<Pair<String, String>>(turns));
120+
if (paramsCustomizer != null) {
121+
paramsCustomizer.accept(params);
122+
}
123+
return params;
124+
}
125+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
package net.ladenthin.llama;
6+
7+
import org.junit.Test;
8+
9+
import static org.junit.Assert.assertEquals;
10+
11+
@ClaudeGenerated(
12+
purpose = "Verify ChatMessage value class accessors and toString format used by Session.getMessages()."
13+
)
14+
public class ChatMessageTest {
15+
16+
@Test
17+
public void accessors() {
18+
ChatMessage m = new ChatMessage("user", "hi");
19+
assertEquals("user", m.getRole());
20+
assertEquals("hi", m.getContent());
21+
}
22+
23+
@Test
24+
public void toStringFormat() {
25+
assertEquals("assistant: hello", new ChatMessage("assistant", "hello").toString());
26+
}
27+
}

src/test/java/net/ladenthin/llama/LlamaModelTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,34 @@ public void testCompleteAsyncCancelPropagates() throws Exception {
314314
Assert.assertNotNull(model.complete(new InferenceParameters(prefix).setNPredict(3)));
315315
}
316316

317+
/**
318+
* Regression: {@link Session} must accumulate user/assistant turns across
319+
* multiple {@link Session#send(String)} calls and expose them via
320+
* {@link Session#getMessages()}. Save/restore round-trip is exercised
321+
* separately in slot save/restore tests.
322+
*/
323+
@Test
324+
public void testSessionMultiTurn() {
325+
try (Session session = new Session(model, 0, "You are a terse assistant.",
326+
params -> params.setNPredict(8).setSeed(1))) {
327+
String r1 = session.send("Say hi.");
328+
Assert.assertNotNull(r1);
329+
String r2 = session.send("Say bye.");
330+
Assert.assertNotNull(r2);
331+
332+
java.util.List<ChatMessage> msgs = session.getMessages();
333+
// system + user + assistant + user + assistant
334+
Assert.assertEquals(5, msgs.size());
335+
Assert.assertEquals("system", msgs.get(0).getRole());
336+
Assert.assertEquals("user", msgs.get(1).getRole());
337+
Assert.assertEquals("Say hi.", msgs.get(1).getContent());
338+
Assert.assertEquals("assistant", msgs.get(2).getRole());
339+
Assert.assertEquals("user", msgs.get(3).getRole());
340+
Assert.assertEquals("Say bye.", msgs.get(3).getContent());
341+
Assert.assertEquals("assistant", msgs.get(4).getRole());
342+
}
343+
}
344+
317345
@Test
318346
public void testEmbedding() {
319347
float[] embedding = model.embed(prefix);

0 commit comments

Comments
 (0)