forked from kherud/java-llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathLlamaModel.java
More file actions
420 lines (370 loc) · 15 KB
/
Copy pathLlamaModel.java
File metadata and controls
420 lines (370 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
package net.ladenthin.llama;
import net.ladenthin.llama.args.LogFormat;
import net.ladenthin.llama.json.ChatResponseParser;
import net.ladenthin.llama.json.CompletionResponseParser;
import net.ladenthin.llama.json.RerankResponseParser;
import java.lang.annotation.Native;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
/**
* This class is a wrapper around the llama.cpp functionality.
* Upon being created, it natively allocates memory for the model context.
* Thus, this class is an {@link AutoCloseable}, in order to de-allocate the memory when it is no longer being needed.
* <p>
* The main functionality of this class is:
* <ul>
* <li>Streaming answers (and probabilities) via {@link #generate(InferenceParameters)}</li>
* <li>Creating whole responses to prompts via {@link #complete(InferenceParameters)}</li>
* <li>Creating embeddings via {@link #embed(String)} (make sure to configure {@link ModelParameters#enableEmbedding()}</li>
* <li>Accessing the tokenizer via {@link #encode(String)} and {@link #decode(int[])}</li>
* </ul>
*/
public class LlamaModel implements AutoCloseable {
static {
LlamaLoader.initialize();
}
@Native
private long ctx;
private final CompletionResponseParser completionParser = new CompletionResponseParser();
private final ChatResponseParser chatParser = new ChatResponseParser();
private final RerankResponseParser rerankParser = new RerankResponseParser();
/**
* Load with the given {@link ModelParameters}. Make sure to either set
* <ul>
* <li>{@link ModelParameters#setModel(String)}</li>
* <li>{@link ModelParameters#setModelUrl(String)}</li>
* <li>{@link ModelParameters#setHfRepo(String)}, {@link ModelParameters#setHfFile(String)}</li>
* </ul>
*
* @param parameters the set of options
* @throws LlamaException if no model could be loaded from the given file path
*/
public LlamaModel(ModelParameters parameters) {
loadModel(parameters.toArray());
}
/**
* Generate and return a whole answer with custom parameters. Note, that the prompt isn't preprocessed in any
* way, nothing like "User: ", "###Instruction", etc. is added.
*
* @param parameters the inference configuration
* @return an LLM response
*/
public String complete(InferenceParameters parameters) {
parameters.setStream(false);
int taskId = requestCompletion(parameters.toString());
String json = receiveCompletionJson(taskId);
return completionParser.parse(json).text;
}
/**
* Generate and stream outputs with custom inference parameters. Note, that the prompt isn't preprocessed in any
* way, nothing like "User: ", "###Instruction", etc. is added.
*
* @param parameters the inference configuration
* @return iterable LLM outputs
*/
public LlamaIterable generate(InferenceParameters parameters) {
return new LlamaIterable(new LlamaIterator(this, parameters));
}
/**
* Get the embedding of a string. Note, that the prompt isn't preprocessed in any way, nothing like
* "User: ", "###Instruction", etc. is added.
*
* @param prompt the string to embed
* @return an embedding float array
* @throws IllegalStateException if embedding mode was not activated (see {@link ModelParameters#enableEmbedding()})
*/
public native float[] embed(String prompt);
/**
* Tokenize a prompt given the native tokenizer
*
* @param prompt the prompt to tokenize
* @return an array of integers each representing a token id
*/
public native int[] encode(String prompt);
/**
* Convert an array of token ids to its string representation
*
* @param tokens an array of tokens
* @return the token ids decoded to a string
*/
public String decode(int[] tokens) {
byte[] bytes = decodeBytes(tokens);
return new String(bytes, StandardCharsets.UTF_8);
}
/**
* Sets a callback for native llama.cpp log messages.
* Per default, log messages are written in JSON to stdout. Note, that in text mode the callback will be also
* invoked with log messages of the GGML backend, while JSON mode can only access request log messages.
* In JSON mode, GGML messages will still be written to stdout.
* To only change the log format but keep logging to stdout, the given callback can be <code>null</code>.
* To disable logging, pass an empty callback, i.e., <code>(level, msg) {@literal ->} {}</code>.
*
* @param format the log format to use
* @param callback a method to call for log messages
*/
public static native void setLogger(LogFormat format, BiConsumer<LogLevel, String> callback);
@Override
public void close() {
delete();
}
// don't overload native methods since the C++ function names get nasty
native int requestCompletion(String params) throws LlamaException;
native String receiveCompletionJson(int taskId) throws LlamaException;
native void cancelCompletion(int taskId);
native byte[] decodeBytes(int[] tokens);
private native void loadModel(String... parameters) throws LlamaException;
private native void delete();
native void releaseTask(int taskId);
private static native byte[] jsonSchemaToGrammarBytes(String schema);
/**
* Converts a JSON schema to a grammar string usable by {@link ModelParameters#setGrammar(String)}.
*
* @param schema the JSON schema as a string
* @return the converted grammar string
*/
public static String jsonSchemaToGrammar(String schema) {
return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8);
}
/**
* Rerank the given documents against the query.
*
* @param reRank whether to sort results by score in descending order
* @param query the query string
* @param documents the documents to rank
* @return a list of document/score pairs, sorted if {@code reRank} is {@code true}
*/
public List<Pair<String, Float>> rerank(boolean reRank, String query, String... documents) {
String json = handleRerank(query, documents);
List<Pair<String, Float>> rankedDocuments = rerankParser.parse(json);
if (reRank) {
rankedDocuments.sort((a, b) -> Float.compare(b.getValue(), a.getValue()));
}
return rankedDocuments;
}
/**
* Rerank the given documents against the query, returning a {@link LlamaOutput} with scored documents
* in the probabilities map.
*
* @param query the query string
* @param documents the documents to rank
* @return a LlamaOutput with document/score pairs in the probabilities map
*/
public LlamaOutput rerank(String query, String... documents) {
String json = handleRerank(query, documents);
List<Pair<String, Float>> results = rerankParser.parse(json);
Map<String, Float> probabilities = new HashMap<>();
for (Pair<String, Float> pair : results) {
probabilities.put(pair.getKey(), pair.getValue());
}
return new LlamaOutput(query, probabilities, true, StopReason.EOS);
}
native String handleRerank(String query, String... documents) throws LlamaException;
/**
* Applies the chat template to the given inference parameters and returns the formatted string.
*
* @param parameters the inference parameters containing message configuration
* @return the formatted chat template string
*/
public String applyTemplate(InferenceParameters parameters) {
return applyTemplate(parameters.toString());
}
/**
* @param parametersJson JSON-serialized inference parameters
* @return the formatted chat template string
*/
public native String applyTemplate(String parametersJson);
/**
* Run an OpenAI-compatible chat completion. The parameters must contain a "messages" array
* in the standard OpenAI chat format (objects with "role" and "content" fields). The model's
* chat template is automatically applied.
* <p>
* Example usage:
* <pre>{@code
* List<Pair<String, String>> messages = new ArrayList<>();
* messages.add(new Pair<>("user", "What is the capital of France?"));
*
* InferenceParameters params = new InferenceParameters("")
* .setMessages("You are a helpful assistant.", messages)
* .setNPredict(128)
* .setTemperature(0.7f);
*
* String response = model.chatComplete(params);
* }</pre>
*
* @param parameters the inference parameters including messages
* @return the model's response as a JSON string containing the completion result
* @throws LlamaException if the model was loaded in embedding mode or if inference fails
*/
public String chatComplete(InferenceParameters parameters) {
parameters.setStream(false);
return handleChatCompletions(parameters.toString());
}
/**
* Run an OpenAI-compatible chat completion and return only the assistant's text content.
* This is the plain-string equivalent of {@link #chatComplete(InferenceParameters)}, which
* returns the raw OAI JSON. Use this when you want the generated text directly, the same
* way {@link #complete(InferenceParameters)} works for raw completions.
*
* @param parameters the inference parameters including messages
* @return the assistant's reply text (extracted from {@code choices[0].message.content})
* @throws LlamaException if the model was loaded in embedding mode or if inference fails
*/
public String chatCompleteText(InferenceParameters parameters) {
return chatParser.extractChoiceContent(chatComplete(parameters));
}
/**
* Stream an OpenAI-compatible chat completion token by token. The parameters must contain a
* "messages" array in the standard OpenAI chat format. The model's chat template is automatically applied.
* <p>
* Example usage:
* <pre>{@code
* List<Pair<String, String>> messages = new ArrayList<>();
* messages.add(new Pair<>("user", "Tell me a story."));
*
* InferenceParameters params = new InferenceParameters("")
* .setMessages("You are a storyteller.", messages)
* .setNPredict(128);
*
* for (LlamaOutput output : model.generateChat(params)) {
* System.out.print(output.text);
* }
* }</pre>
*
* @param parameters the inference parameters including messages
* @return iterable LLM outputs with the chat template applied
* @throws LlamaException if inference fails
*/
public LlamaIterable generateChat(InferenceParameters parameters) {
return new LlamaIterable(new LlamaIterator(this, parameters, true));
}
/**
* Run a blocking completion and return the full result as a JSON string.
* This is the JSON-in/JSON-out equivalent of {@link #complete(InferenceParameters)}.
*
* @param paramsJson JSON string with at least a "prompt" field
* @return JSON response from the server
*/
public native String handleCompletions(String paramsJson) throws LlamaException;
/**
* Run an OpenAI-compatible completion (mirrors /v1/completions endpoint).
* Returns the result in OAI format with choices array.
*
* @param paramsJson JSON string with OAI-compatible completion parameters
* @return JSON response in OAI format
*/
public native String handleCompletionsOai(String paramsJson) throws LlamaException;
/**
* Run a text infill completion with explicit prefix/suffix.
* The request JSON must contain "input_prefix" and "input_suffix" fields.
*
* @param paramsJson JSON string with infill parameters
* @return JSON response from the server
*/
public native String handleInfill(String paramsJson) throws LlamaException;
/**
* Generate embeddings for the given input. The request JSON should contain
* an "input" (OAI-compat) or "content" field.
*
* @param paramsJson JSON string with embedding request
* @param oaiCompat whether to format the response in OAI-compatible format
* @return JSON response with embedding vectors
*/
public native String handleEmbeddings(String paramsJson, boolean oaiCompat) throws LlamaException;
/**
* Tokenize text content, optionally including token piece information.
*
* @param content the text to tokenize
* @param addSpecial whether to add special tokens (BOS/EOS)
* @param withPieces whether to include token piece strings in the response
* @return JSON response with token data
*/
public native String handleTokenize(String content, boolean addSpecial, boolean withPieces) throws LlamaException;
/**
* Detokenize an array of token IDs back to text.
*
* @param tokens array of token IDs
* @return JSON response with the decoded text
*/
public native String handleDetokenize(int[] tokens) throws LlamaException;
// ------------------------------------------------------------------
// Server management
// ------------------------------------------------------------------
/**
* Get server metrics and slot information as a JSON string.
*
* @return JSON with slot data, idle/processing counts, and performance metrics
*/
public String getMetrics() {
return handleSlotAction(0, 0, null);
}
private static final com.fasterxml.jackson.databind.ObjectMapper OBJECT_MAPPER =
new com.fasterxml.jackson.databind.ObjectMapper();
/**
* Returns model metadata with typed accessors for vocab, context, embedding,
* parameter count, size, and modality support flags (vision, audio).
* <p>
* The returned {@link ModelMeta} wraps the raw JSON from the native layer.
* Call {@link ModelMeta#toString()} to re-serialize to compact JSON for use
* in {@code assertEquals}.
* </p>
*
* @return {@link ModelMeta} parsed from the native {@code model_meta()} response
* @throws LlamaException if the native call fails or the response cannot be parsed
*/
public ModelMeta getModelMeta() throws LlamaException {
try {
return new ModelMeta(OBJECT_MAPPER.readTree(getModelMetaJson()));
} catch (java.io.IOException e) {
throw new LlamaException("Failed to parse model meta JSON: " + e.getMessage());
}
}
native String getModelMetaJson() throws LlamaException;
/**
* Erase the KV cache for a specific slot.
*
* @param slotId the slot ID to erase
* @return JSON with erase result
*/
public String eraseSlot(int slotId) {
return handleSlotAction(3, slotId, null);
}
/**
* Save a slot's KV cache state to a file.
*
* @param slotId the slot ID to save
* @param filepath the file path to save to
* @return JSON with save result
*/
public String saveSlot(int slotId, String filepath) {
return handleSlotAction(1, slotId, filepath);
}
/**
* Restore a slot's KV cache state from a file.
*
* @param slotId the slot ID to restore
* @param filepath the file path to restore from
* @return JSON with restore result
*/
public String restoreSlot(int slotId, String filepath) {
return handleSlotAction(2, slotId, filepath);
}
/**
* Configure runtime inference parameters.
* Accepts a JSON string with optional keys:
* <ul>
* <li>"slot_prompt_similarity" (float, 0.0-1.0)</li>
* <li>"n_threads" (int, > 0)</li>
* <li>"n_threads_batch" (int, > 0)</li>
* </ul>
*
* @param configJson JSON configuration string
* @return true if configuration was applied successfully
*/
public native boolean configureParallelInference(String configJson) throws LlamaException;
native String handleSlotAction(int action, int slotId, String filename) throws LlamaException;
native String handleChatCompletions(String params) throws LlamaException;
native int requestChatCompletion(String params) throws LlamaException;
}