forked from bernardladenthin/java-llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathParameterJsonSerializer.java
More file actions
303 lines (282 loc) · 12.4 KB
/
Copy pathParameterJsonSerializer.java
File metadata and controls
303 lines (282 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
// SPDX-FileCopyrightText: 2023-2025 Konstantin Herud
//
// SPDX-License-Identifier: MIT
package net.ladenthin.llama.parameters;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import net.ladenthin.llama.args.Sampler;
import net.ladenthin.llama.value.ChatMessage;
import net.ladenthin.llama.value.ContentPart;
import net.ladenthin.llama.value.Pair;
import org.jspecify.annotations.Nullable;
/**
* Pure JSON builders for inference request parameters.
*
* <p>All methods are stateless and have zero dependency on JNI, native libraries, or llama
* model state — they can be tested with plain Java values alone (see
* {@code ParameterJsonSerializerTest}).
*
* <p>Methods return Jackson {@link ArrayNode} or {@link ObjectNode}. Callers that need a JSON
* string (e.g. callers in {@code JsonParameters}) call {@code node.toString()}.
*
* <p>This class replaces hand-rolled {@code StringBuilder} loops and the
* {@code org.json}-derived {@code toJsonString()} escaper previously embedded in
* {@code JsonParameters}.
*/
public class ParameterJsonSerializer {
/** Creates a new {@link ParameterJsonSerializer}. */
public ParameterJsonSerializer() {}
/** Shared Jackson mapper; thread-safe and reused across all instances. */
public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
// ------------------------------------------------------------------
// String escaping
// ------------------------------------------------------------------
/**
* Serialize a Java string to a quoted, properly escaped JSON string literal
* (e.g. {@code "hello\nworld"} → {@code "\"hello\\nworld\""}).
* Returns {@code "null"} for a {@code null} input.
*
* <p>Replaces the hand-rolled {@code toJsonString()} method in {@code JsonParameters}.
*
* @param value the Java string to serialize, or {@code null}
* @return a JSON string literal, or {@code "null"} if the input is {@code null}
*/
public String toJsonString(String value) {
if (value == null) return "null";
try {
return OBJECT_MAPPER.writeValueAsString(value);
} catch (JsonProcessingException e) {
return "null";
}
}
// ------------------------------------------------------------------
// Message array
// ------------------------------------------------------------------
/**
* Build an OAI-compatible {@code messages} array node.
*
* <p>An optional system message is prepended when non-null and non-empty.
* Each message in {@code messages} must have role {@code "user"} or {@code "assistant"}.
*
* @param systemMessage optional system prompt; skipped when {@code null} or empty
* @param messages list of user/assistant message pairs (role as key, content as value)
* @return a Jackson {@link ArrayNode} of {@code {"role", "content"}} message objects
* @throws IllegalArgumentException if any message has an invalid role
*/
public ArrayNode buildMessages(@Nullable String systemMessage, List<Pair<String, String>> messages) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
if (systemMessage != null && !systemMessage.isEmpty()) {
ObjectNode sys = OBJECT_MAPPER.createObjectNode();
sys.put("role", "system");
sys.put("content", systemMessage);
arr.add(sys);
}
for (Pair<String, String> message : messages) {
String role = message.getKey();
String content = message.getValue();
if (!"user".equals(role) && !"assistant".equals(role)) {
throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'.");
}
ObjectNode msg = OBJECT_MAPPER.createObjectNode();
msg.put("role", role);
msg.put("content", content);
arr.add(msg);
}
return arr;
}
/**
* Multimodal-capable variant of {@link #buildMessages(String, List)}. Accepts
* {@link net.ladenthin.llama.value.ChatMessage} objects directly so messages with non-null
* {@link net.ladenthin.llama.value.ChatMessage#getParts()} can be serialized as the OAI-compatible
* array-form {@code content} the upstream {@code mtmd} pipeline expects.
* Plain text messages still emit the legacy string-form {@code content}.
*
* @param messages messages in order; must not be {@code null}
* @return a Jackson {@link ArrayNode} ready for the {@code messages} request field
*/
public ArrayNode buildMessages(List<ChatMessage> messages) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (ChatMessage message : messages) {
ObjectNode msg = OBJECT_MAPPER.createObjectNode();
msg.put("role", message.getRole());
if (message.hasParts()) {
ArrayNode parts = OBJECT_MAPPER.createArrayNode();
for (ContentPart p : message.getParts()
.orElseThrow(() -> new IllegalStateException("hasParts() was true but getParts() was empty"))) {
ObjectNode part = OBJECT_MAPPER.createObjectNode();
if (p.getType() == ContentPart.Type.TEXT) {
part.put("type", "text");
final String text = p.getText();
part.put("text", text != null ? text : "");
} else if (p.getType() == ContentPart.Type.INPUT_AUDIO) {
part.put("type", "input_audio");
ObjectNode inputAudio = OBJECT_MAPPER.createObjectNode();
final String data = p.getAudioData();
final String format = p.getAudioFormat();
inputAudio.put("data", data != null ? data : "");
inputAudio.put("format", format != null ? format : "wav");
part.set("input_audio", inputAudio);
} else {
part.put("type", "image_url");
ObjectNode imageUrl = OBJECT_MAPPER.createObjectNode();
final String url = p.getImageUrl();
imageUrl.put("url", url != null ? url : "");
part.set("image_url", imageUrl);
}
parts.add(part);
}
msg.set("content", parts);
} else {
msg.put("content", message.getContent());
}
arr.add(msg);
}
return arr;
}
// ------------------------------------------------------------------
// Simple array builders
// ------------------------------------------------------------------
/**
* Build a JSON string array from the given stop strings
* (e.g. {@code ["<|endoftext|>", "\n"]}).
*
* @param stops one or more stop strings
* @return a Jackson {@link ArrayNode} of stop string values
*/
public ArrayNode buildStopStrings(String... stops) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (String stop : stops) arr.add(stop);
return arr;
}
/**
* Build a JSON string array from the given sampler sequence
* (e.g. {@code ["top_k", "top_p", "temperature"]}).
*
* @param samplers one or more samplers in the desired order
* @return a Jackson {@link ArrayNode} of sampler name strings
*/
public ArrayNode buildSamplers(Sampler... samplers) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (Sampler sampler : samplers) {
arr.add(sampler.getArgValue());
}
return arr;
}
/**
* Build a JSON integer array from a primitive {@code int[]}
* (used for penalty-prompt token sequences).
*
* @param values the token IDs to include
* @return a Jackson {@link ArrayNode} of integer values
*/
public ArrayNode buildIntArray(int... values) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (int v : values) arr.add(v);
return arr;
}
// ------------------------------------------------------------------
// Logit-bias pair arrays — [[key, value], ...]
// ------------------------------------------------------------------
/**
* Build a logit-bias array for integer token IDs:
* {@code [[15043, 1.0], [50256, -0.5]]}.
*
* @param biases map from token ID to logit bias value
* @return a Jackson {@link ArrayNode} of {@code [tokenId, biasValue]} pairs
*/
public ArrayNode buildTokenIdBiasArray(Map<Integer, Float> biases) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (Map.Entry<Integer, Float> entry : biases.entrySet()) {
ArrayNode pair = OBJECT_MAPPER.createArrayNode();
pair.add(entry.getKey());
pair.add(entry.getValue());
arr.add(pair);
}
return arr;
}
/**
* Build a logit-bias array for string tokens:
* {@code [["Hello", 1.0], [" world", -0.5]]}.
*
* @param biases map from token string to logit bias value
* @return a Jackson {@link ArrayNode} of {@code ["token", biasValue]} pairs
*/
public ArrayNode buildTokenStringBiasArray(Map<String, Float> biases) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (Map.Entry<String, Float> entry : biases.entrySet()) {
ArrayNode pair = OBJECT_MAPPER.createArrayNode();
pair.add(entry.getKey());
pair.add(entry.getValue());
arr.add(pair);
}
return arr;
}
/**
* Build a disable-token array for integer token IDs:
* {@code [[15043, false], [50256, false]]}.
*
* @param ids collection of integer token IDs to disable
* @return a Jackson {@link ArrayNode} of {@code [tokenId, false]} pairs
*/
public ArrayNode buildDisableTokenIdArray(Collection<Integer> ids) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (Integer id : ids) {
ArrayNode pair = OBJECT_MAPPER.createArrayNode();
pair.add(id);
pair.add(false);
arr.add(pair);
}
return arr;
}
/**
* Build a disable-token array for string tokens:
* {@code [["Hello", false], [" world", false]]}.
*
* @param tokens collection of token strings to disable
* @return a Jackson {@link ArrayNode} of {@code ["token", false]} pairs
*/
public ArrayNode buildDisableTokenStringArray(Collection<String> tokens) {
ArrayNode arr = OBJECT_MAPPER.createArrayNode();
for (String token : tokens) {
ArrayNode pair = OBJECT_MAPPER.createArrayNode();
pair.add(token);
pair.add(false);
arr.add(pair);
}
return arr;
}
// ------------------------------------------------------------------
// Object with pre-serialized JSON values
// ------------------------------------------------------------------
/**
* Build a JSON object where each map value is a <em>pre-serialized JSON string</em>
* (not a plain Java string). For example, a map entry {@code ("enable_thinking", "true")}
* produces {@code {"enable_thinking": true}}, not {@code {"enable_thinking": "true"}}.
*
* <p>Used for {@code chat_template_kwargs} which stores raw JSON values.
* If a value cannot be parsed as JSON, it is stored as a JSON string literal.
*
* @param map map of key to pre-serialized JSON value strings
* @return a Jackson {@link ObjectNode} with each value embedded as a parsed JSON node
*/
public ObjectNode buildRawValueObject(Map<String, String> map) {
ObjectNode node = OBJECT_MAPPER.createObjectNode();
for (Map.Entry<String, String> entry : map.entrySet()) {
try {
JsonNode val = OBJECT_MAPPER.readTree(entry.getValue());
node.set(entry.getKey(), val);
} catch (IOException e) {
node.put(entry.getKey(), entry.getValue());
}
}
return node;
}
}