1717 * {@link #restore(String)}, which delegate to {@link LlamaModel#saveSlot(int, String)}
1818 * and {@link LlamaModel#restoreSlot(int, String)}.
1919 * <p>
20- * This wrapper is intentionally not thread-safe; callers must serialize access to a
21- * single {@code Session} instance. Concurrency support is a follow-up (M-effort) item.
20+ * Thread-safety: all public methods are serialized on a private intrinsic lock, so
21+ * concurrent {@link #send(String)} calls from multiple threads produce a well-formed
22+ * transcript with strict user/assistant alternation. {@link #stream(String)} sets a
23+ * "streaming in progress" flag and returns the iterator without holding the lock;
24+ * while that flag is set, {@link #send(String)}, a second {@link #stream(String)},
25+ * {@link #save(String)}, and {@link #restore(String)} fail fast with
26+ * {@link IllegalStateException} until the caller invokes
27+ * {@link #commitStreamedReply(String)}.
2228 * </p>
2329 */
2430public final class Session implements AutoCloseable {
@@ -28,6 +34,8 @@ public final class Session implements AutoCloseable {
// Optional system prompt; getMessages() emits it as a leading "system" message when non-null and non-empty.
2834 private final String systemMessage ;
// Ordered transcript of (role, text) pairs; roles written by this class are "user" and "assistant".
2935 private final List <Pair <String , String >> turns = new ArrayList <Pair <String , String >>();
// NOTE(review): presumably applied to the InferenceParameters inside buildParams() -- confirm against that method.
3036 private final Consumer <InferenceParameters > paramsCustomizer ;
// Private monitor object; every public method body runs inside synchronized (lock).
37+ private final Object lock = new Object ();
// Set to true by stream(...) and back to false by commitStreamedReply(...); only read or written while holding lock.
38+ private boolean streamingActive ;
3139
3240 /**
3341 * Create a session bound to a specific slot id, with an optional system prompt
@@ -65,11 +73,22 @@ public Session(LlamaModel model, int slotId, String systemMessage,
6573 * @return the assistant's reply text
6674 */
6775 public String send (String userMessage ) {
68- turns .add (new Pair <String , String >("user" , userMessage ));
69- InferenceParameters params = buildParams ();
70- String reply = model .chatCompleteText (params );
71- turns .add (new Pair <String , String >("assistant" , reply ));
72- return reply ;
76+ synchronized (lock ) {
77+ if (streamingActive ) {
78+ throw new IllegalStateException (
79+ "stream in progress; call commitStreamedReply(...) before send(...)" );
80+ }
81+ turns .add (new Pair <String , String >("user" , userMessage ));
82+ InferenceParameters params = buildParams ();
83+ try {
84+ String reply = model .chatCompleteText (params );
85+ turns .add (new Pair <String , String >("assistant" , reply ));
86+ return reply ;
87+ } catch (RuntimeException e ) {
88+ turns .remove (turns .size () - 1 );
89+ throw e ;
90+ }
91+ }
7392 }
7493
7594 /**
@@ -82,8 +101,21 @@ public String send(String userMessage) {
82101 * @return a {@link LlamaIterable} that yields assistant reply chunks
83102 */
84103 public LlamaIterable stream (String userMessage ) {
85- turns .add (new Pair <String , String >("user" , userMessage ));
86- return model .generateChat (buildParams ());
104+ synchronized (lock ) {
105+ if (streamingActive ) {
106+ throw new IllegalStateException (
107+ "stream in progress; call commitStreamedReply(...) before stream(...)" );
108+ }
109+ turns .add (new Pair <String , String >("user" , userMessage ));
110+ try {
111+ LlamaIterable iterable = model .generateChat (buildParams ());
112+ streamingActive = true ;
113+ return iterable ;
114+ } catch (RuntimeException e ) {
115+ turns .remove (turns .size () - 1 );
116+ throw e ;
117+ }
118+ }
87119 }
88120
89121 /**
@@ -93,7 +125,14 @@ public LlamaIterable stream(String userMessage) {
93125 * @param assistantText the assistant text accumulated from a prior {@link #stream(String)} call
94126 */
95127 public void commitStreamedReply (String assistantText ) {
96- turns .add (new Pair <String , String >("assistant" , assistantText ));
128+ synchronized (lock ) {
129+ if (!streamingActive ) {
130+ throw new IllegalStateException (
131+ "no stream in progress; call stream(...) first" );
132+ }
133+ turns .add (new Pair <String , String >("assistant" , assistantText ));
134+ streamingActive = false ;
135+ }
97136 }
98137
99138 /**
@@ -103,7 +142,13 @@ public void commitStreamedReply(String assistantText) {
103142 * @return the JSON response from the native save action
104143 */
105144 public String save (String filepath ) {
106- return model .saveSlot (slotId , filepath );
145+ synchronized (lock ) {
146+ if (streamingActive ) {
147+ throw new IllegalStateException (
148+ "stream in progress; call commitStreamedReply(...) before save(...)" );
149+ }
150+ return model .saveSlot (slotId , filepath );
151+ }
107152 }
108153
109154 /**
@@ -113,28 +158,38 @@ public String save(String filepath) {
113158 * @return the JSON response from the native restore action
114159 */
115160 public String restore (String filepath ) {
116- return model .restoreSlot (slotId , filepath );
161+ synchronized (lock ) {
162+ if (streamingActive ) {
163+ throw new IllegalStateException (
164+ "stream in progress; call commitStreamedReply(...) before restore(...)" );
165+ }
166+ return model .restoreSlot (slotId , filepath );
167+ }
117168 }
118169
119170 /**
120171 * Transcript accessor.
121172 * @return the accumulated transcript so far, in order, including the system message if any
122173 */
123174 public List <ChatMessage > getMessages () {
124- List <ChatMessage > out = new ArrayList <ChatMessage >(turns .size () + 1 );
125- if (systemMessage != null && !systemMessage .isEmpty ()) {
126- out .add (new ChatMessage ("system" , systemMessage ));
127- }
128- for (Pair <String , String > p : turns ) {
129- out .add (new ChatMessage (p .getKey (), p .getValue ()));
175+ synchronized (lock ) {
176+ List <ChatMessage > out = new ArrayList <ChatMessage >(turns .size () + 1 );
177+ if (systemMessage != null && !systemMessage .isEmpty ()) {
178+ out .add (new ChatMessage ("system" , systemMessage ));
179+ }
180+ for (Pair <String , String > p : turns ) {
181+ out .add (new ChatMessage (p .getKey (), p .getValue ()));
182+ }
183+ return Collections .unmodifiableList (out );
130184 }
131- return Collections .unmodifiableList (out );
132185 }
133186
134187 /** Erase the bound slot's KV cache. Does not modify the in-memory transcript. */
135188 @ Override
136189 public void close () {
137- model .eraseSlot (slotId );
190+ synchronized (lock ) {
191+ model .eraseSlot (slotId );
192+ }
138193 }
139194
140195 private InferenceParameters buildParams () {
0 commit comments