forked from kherud/java-llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathSession.java
More file actions
207 lines (190 loc) · 9.66 KB
/
Copy pathSession.java
File metadata and controls
207 lines (190 loc) · 9.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
//
// SPDX-License-Identifier: MIT
package net.ladenthin.llama;
import java.util.List;
import java.util.Objects;
import java.util.function.UnaryOperator;
import lombok.ToString;
import net.ladenthin.llama.parameters.InferenceParameters;
import net.ladenthin.llama.value.ChatMessage;
import net.ladenthin.llama.value.Pair;
import org.jspecify.annotations.Nullable;
/**
 * Thin multi-turn conversation wrapper over a {@link LlamaModel} slot. Maintains an
 * accumulating list of {@link net.ladenthin.llama.value.ChatMessage} turns and forwards each
 * {@link #send(String)} to the underlying chat-completion API with the full transcript so far.
 * KV-cache state for the bound slot can be persisted via {@link #save(String)} and restored
 * with {@link #restore(String)}, which delegate to {@link LlamaModel#saveSlot(int, String)}
 * and {@link LlamaModel#restoreSlot(int, String)}.
 *
 * <p>Thread-safety: all public methods are serialized on a private intrinsic lock, so
 * concurrent {@link #send(String)} calls from multiple threads produce a well-formed
 * transcript with strict user/assistant alternation. {@link #stream(String)} sets a
 * "streaming in progress" flag and returns the iterator without holding the lock;
 * while that flag is set, {@link #send(String)}, a second {@link #stream(String)},
 * {@link #save(String)}, and {@link #restore(String)} fail-fast with
 * {@link IllegalStateException} until the caller invokes
 * {@link #commitStreamedReply(String)}.</p>
 *
 * <p>{@code toString} is generated by Lombok over the slot id and the
 * {@link SessionState} (which renders the transcript and streaming flag). The owning
 * {@link LlamaModel} is excluded because its {@code toString} would render native
 * state. The {@code paramsCustomizer} {@link UnaryOperator} is excluded because lambda
 * {@code toString} is the implementation hash, not useful in logs.
 * {@code equals}/{@code hashCode} are intentionally NOT generated: a session is a
 * mutable lifecycle handle managed by identity.</p>
 */
@ToString
public final class Session implements AutoCloseable {

    // Owning model — its toString would recursively render native state.
    @ToString.Exclude
    private final LlamaModel model;

    private final int slotId;

    /**
     * The lock-guarded streaming-flag + transcript state machine. Extracted to
     * {@link SessionState} so its concurrency contract (the two-phase commit and the
     * streaming guard) is testable without the native model; see that class and the
     * {@code SessionStateInterleavingTest}. This {@code Session} only adds the model
     * calls, injected as callbacks that {@link SessionState} runs under its lock.
     */
    private final SessionState state;

    // Lambda UnaryOperator — toString is the implementation hash, not useful in logs.
    @ToString.Exclude
    private final @Nullable UnaryOperator<InferenceParameters> paramsCustomizer;

    /**
     * Create a session bound to a specific slot id, with an optional system prompt
     * applied to every {@link #send(String)} call. Equivalent to the four-argument
     * constructor with a {@code null} parameter customizer.
     *
     * @param model the underlying model
     * @param slotId the slot id used by {@link #save(String)} / {@link #restore(String)};
     *     must be non-negative
     * @param systemMessage optional system prompt (may be {@code null} or empty)
     * @throws IllegalArgumentException if {@code slotId} is negative
     */
    public Session(LlamaModel model, int slotId, @Nullable String systemMessage) {
        this(model, slotId, systemMessage, null);
    }

    /**
     * Create a session with a customizer that transforms the
     * {@link net.ladenthin.llama.parameters.InferenceParameters} for every call
     * (e.g. {@code p -> p.withTemperature(0.7f).withNPredict(64)}).
     * Because {@link net.ladenthin.llama.parameters.InferenceParameters} is immutable, the
     * customizer must return the transformed instance — it cannot mutate the input, and it
     * must not return {@code null}. Any slot id the customizer sets is overridden with this
     * session's {@code slotId} before the request is issued.
     *
     * @param model the underlying model
     * @param slotId the slot id; must be non-negative (a session is pinned to one concrete slot
     *     for both inference and {@link #save(String)} / {@link #restore(String)} / {@link #close()})
     * @param systemMessage optional system prompt
     * @param paramsCustomizer applied to each request's parameters; may be {@code null}
     * @throws IllegalArgumentException if {@code slotId} is negative
     */
    public Session(
            LlamaModel model,
            int slotId,
            @Nullable String systemMessage,
            @Nullable UnaryOperator<InferenceParameters> paramsCustomizer) {
        // Validate here, not per request: every send()/stream() pins this slot id (see
        // buildParams), and the slot also backs save()/restore()/close(). A negative id is
        // meaningless for those, so reject it up front with a clear message rather than letting
        // InferenceParameters.withSlotId throw on the first inference call.
        if (slotId < 0) {
            throw new IllegalArgumentException("slotId must be non-negative, was " + slotId);
        }
        this.model = model;
        this.slotId = slotId;
        this.state = new SessionState(slotId, systemMessage);
        this.paramsCustomizer = paramsCustomizer;
    }

    /**
     * Send a user message and return the assistant's text reply, appending both to the transcript.
     *
     * @param userMessage the user turn to append before invoking the model
     * @return the assistant's reply text
     * @throws IllegalStateException if a previous {@link #stream(String)} has not yet been
     *     completed with {@link #commitStreamedReply(String)}
     */
    public String send(String userMessage) {
        // Two-phase commit lives in SessionState.send(...): it guards against an
        // in-progress stream, builds the wire-format with the pending user turn, runs
        // the model call below under the lock, and on success commits BOTH turns
        // atomically. On model failure nothing is committed — no rollback needed.
        return state.send(
                userMessage,
                (systemMessage, wireMessages) ->
                        model.chatCompleteText(buildParams(systemMessage, wireMessages)));
    }

    /**
     * Streaming variant of {@link #send(String)}. The returned iterable yields chunks of
     * the assistant reply; consume it fully (or via try-with-resources) and then call
     * {@link #commitStreamedReply(String)} before calling {@link #send(String)} again,
     * because the assistant turn is only appended to the transcript when the caller
     * invokes {@link #commitStreamedReply(String)}.
     *
     * @param userMessage the user turn to append before starting the stream
     * @return a {@link LlamaIterable} that yields assistant reply chunks
     * @throws IllegalStateException if a previous {@link #stream(String)} has not yet been
     *     completed with {@link #commitStreamedReply(String)}
     */
    public LlamaIterable stream(String userMessage) {
        // SessionState.beginStream(...) guards against an in-progress stream, runs the
        // model call below under the lock, and on success commits the user turn and
        // marks streaming active; the assistant turn is committed separately by
        // commitStreamedReply(...).
        return state.beginStream(
                userMessage,
                (systemMessage, wireMessages) ->
                        model.generateChat(buildParams(systemMessage, wireMessages)));
    }

    /**
     * Record an assistant reply that was produced by a previous {@link #stream(String)}
     * call. Called by the caller after it has accumulated the streamed text.
     *
     * @param assistantText the assistant text accumulated from a prior {@link #stream(String)} call
     */
    public void commitStreamedReply(String assistantText) {
        state.commitStreamedReply(assistantText);
    }

    /**
     * Save this session's slot KV cache to {@code filepath}.
     *
     * @param filepath destination file path passed to {@link LlamaModel#saveSlot(int, String)}
     * @return the JSON response from the native save action
     * @throws IllegalStateException if a {@link #stream(String)} is still awaiting
     *     {@link #commitStreamedReply(String)}
     */
    public String save(String filepath) {
        return state.runWhenNotStreaming("save", () -> model.saveSlot(slotId, filepath));
    }

    /**
     * Restore this session's slot KV cache from {@code filepath}.
     *
     * @param filepath source file path passed to {@link LlamaModel#restoreSlot(int, String)}
     * @return the JSON response from the native restore action
     * @throws IllegalStateException if a {@link #stream(String)} is still awaiting
     *     {@link #commitStreamedReply(String)}
     */
    public String restore(String filepath) {
        return state.runWhenNotStreaming("restore", () -> model.restoreSlot(slotId, filepath));
    }

    /**
     * Transcript accessor.
     *
     * @return the accumulated transcript so far, in order, including the system message if any
     */
    public List<ChatMessage> getMessages() {
        return state.snapshot();
    }

    /** Erase the bound slot's KV cache. Does not modify the in-memory transcript. */
    @Override
    public void close() {
        // NOTE(review): this goes through runUnderLock, not runWhenNotStreaming, so presumably
        // close() is not rejected while a stream is in progress — confirm against SessionState.
        state.runUnderLock(() -> model.eraseSlot(slotId));
    }

    /**
     * Build inference parameters from the system message and the wire-format messages
     * supplied by {@link SessionState} (the committed turns plus the pending user
     * turn), applying the optional {@code paramsCustomizer}. The transcript itself is
     * never mutated here; {@link SessionState} commits turns only after the model call
     * returns successfully.
     *
     * @param systemMessage the system prompt, or {@code null} when none was configured
     * @param wireMessages the committed turns plus the pending user turn
     * @return inference parameters carrying the system message + wire messages, with prompt
     *     caching enabled and the slot id pinned to this session's slot
     * @throws NullPointerException if the configured {@code paramsCustomizer} returns {@code null}
     */
    private InferenceParameters buildParams(
            @Nullable String systemMessage, List<Pair<String, String>> wireMessages) {
        InferenceParameters params = InferenceParameters.empty()
                .withMessages(systemMessage, wireMessages)
                .withCachePrompt(true);
        if (paramsCustomizer != null) {
            // A customizer that forgets to return the transformed instance would otherwise
            // surface as a message-less NullPointerException on withSlotId below; name the
            // culprit instead. The exception type is unchanged.
            params = Objects.requireNonNull(
                    paramsCustomizer.apply(params),
                    "paramsCustomizer returned null; it must return the transformed InferenceParameters");
        }
        // Apply last: a Session must never drift away from the slot used by
        // save(), restore(), and close(), even if a customizer supplies another id.
        return params.withSlotId(slotId);
    }
}