Skip to content

Commit de457b2

Browse files
committed
Add completeBatch / chatBatch parallel dispatch (§2.4)
Three new methods on LlamaModel that hand a list of requests to the native scheduler at once and collect results in input order: - completeBatch(List<InferenceParameters>) -> List<String> - completeBatchWithStats(List<InferenceParameters>) -> List<CompletionResult> - chatBatch(List<ChatRequest>) -> List<ChatResponse> Implementation reuses the existing CompletableFuture wrappers (completeAsync, supplyAsync(() -> completeWithStats/chat)) and joins them all in input order. The native worker thread runs the upstream slot scheduler, which dispatches tasks across however many slots ModelParameters.setParallel(N) was configured with. With the default N=1 the batch still works correctly, just sequentially. No JNI changes — the upstream scheduler already supports parallel slot execution; this surfaces it as a typed Java API. Three model-gated tests in LlamaModelTest exercise the order-preserving contract and per-result Usage population. mvn javadoc:jar BUILD SUCCESS, no new warnings. https://claude.ai/code/session_01R4ZrEy3ptJDLuUgUKuM4Gy
1 parent c529499 commit de457b2

2 files changed

Lines changed: 106 additions & 0 deletions

File tree

src/main/java/net/ladenthin/llama/LlamaModel.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,67 @@ public CompletionResult completeWithStats(InferenceParameters parameters) {
120120
* @param token cancellation handle; {@link CancellationToken#cancel()} aborts the loop
121121
* @return the text generated up to the point of stop or cancellation
122122
*/
123+
/**
124+
* Dispatch a list of completion requests in parallel and return the generated texts
125+
* in the same order. Each request is sent immediately; the native scheduler dispatches
126+
* tasks across whatever slot count {@link ModelParameters#setParallel(int)} was
127+
* configured with. With a default single-slot model the requests still run, but
128+
* sequentially.
129+
*
130+
* @param requests the list of inference parameter blocks (must be distinct instances)
131+
* @return the generated texts in input order
132+
*/
133+
public java.util.List<String> completeBatch(java.util.List<InferenceParameters> requests) {
134+
java.util.List<CompletableFuture<String>> futures = new java.util.ArrayList<CompletableFuture<String>>(requests.size());
135+
for (InferenceParameters req : requests) {
136+
futures.add(completeAsync(req));
137+
}
138+
java.util.List<String> out = new java.util.ArrayList<String>(futures.size());
139+
for (CompletableFuture<String> f : futures) {
140+
out.add(f.join());
141+
}
142+
return out;
143+
}
144+
145+
/**
146+
* Like {@link #completeBatch(java.util.List)} but each result carries
147+
* {@link CompletionResult}'s typed Usage, Timings, logprobs, and stop reason.
148+
*
149+
* @param requests the list of inference parameter blocks (must be distinct instances)
150+
* @return parsed completion results in input order
151+
*/
152+
public java.util.List<CompletionResult> completeBatchWithStats(java.util.List<InferenceParameters> requests) {
153+
java.util.List<CompletableFuture<CompletionResult>> futures = new java.util.ArrayList<CompletableFuture<CompletionResult>>(requests.size());
154+
for (final InferenceParameters req : requests) {
155+
futures.add(CompletableFuture.supplyAsync(() -> completeWithStats(req)));
156+
}
157+
java.util.List<CompletionResult> out = new java.util.ArrayList<CompletionResult>(futures.size());
158+
for (CompletableFuture<CompletionResult> f : futures) {
159+
out.add(f.join());
160+
}
161+
return out;
162+
}
163+
164+
/**
165+
* Dispatch a list of typed chat requests in parallel and return the parsed responses
166+
* in the same order. Requires {@link ModelParameters#setParallel(int)} &gt; 1 for
167+
* actual parallelism; otherwise the calls run sequentially on the single slot.
168+
*
169+
* @param requests the typed chat requests (must be distinct instances)
170+
* @return parsed responses in input order
171+
*/
172+
public java.util.List<ChatResponse> chatBatch(java.util.List<ChatRequest> requests) {
173+
java.util.List<CompletableFuture<ChatResponse>> futures = new java.util.ArrayList<CompletableFuture<ChatResponse>>(requests.size());
174+
for (final ChatRequest req : requests) {
175+
futures.add(CompletableFuture.supplyAsync(() -> chat(req)));
176+
}
177+
java.util.List<ChatResponse> out = new java.util.ArrayList<ChatResponse>(futures.size());
178+
for (CompletableFuture<ChatResponse> f : futures) {
179+
out.add(f.join());
180+
}
181+
return out;
182+
}
183+
123184
/**
124185
* Asynchronous variant of {@link #complete(InferenceParameters)}. Runs the inference on
125186
* the common {@link java.util.concurrent.ForkJoinPool} so it does not block the calling

src/test/java/net/ladenthin/llama/LlamaModelTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,51 @@ public void testChatWithToolsLoopShortCircuits() {
389389
Assert.assertFalse(r.getChoices().isEmpty());
390390
}
391391

392+
/**
393+
* Regression: {@link LlamaModel#completeBatch(java.util.List)} returns results in
394+
* the same order as the input list, with one non-null text per request. The shared
395+
* test model is single-slot, so this primarily exercises the parallel dispatch and
396+
* order-preservation contract, not actual parallel throughput.
397+
*/
398+
@Test
399+
public void testCompleteBatch() {
400+
java.util.List<InferenceParameters> requests = java.util.Arrays.asList(
401+
new InferenceParameters(prefix).setNPredict(3).setSeed(1),
402+
new InferenceParameters(prefix).setNPredict(3).setSeed(2),
403+
new InferenceParameters(prefix).setNPredict(3).setSeed(3));
404+
java.util.List<String> results = model.completeBatch(requests);
405+
Assert.assertEquals(3, results.size());
406+
for (String r : results) {
407+
Assert.assertNotNull(r);
408+
}
409+
}
410+
411+
@Test
412+
public void testCompleteBatchWithStats() {
413+
java.util.List<InferenceParameters> requests = java.util.Arrays.asList(
414+
new InferenceParameters(prefix).setNPredict(3).setSeed(1),
415+
new InferenceParameters(prefix).setNPredict(3).setSeed(2));
416+
java.util.List<CompletionResult> results = model.completeBatchWithStats(requests);
417+
Assert.assertEquals(2, results.size());
418+
for (CompletionResult r : results) {
419+
Assert.assertNotNull(r);
420+
Assert.assertTrue("expected non-zero total tokens, got " + r.getUsage().getTotalTokens(),
421+
r.getUsage().getTotalTokens() > 0);
422+
}
423+
}
424+
425+
@Test
426+
public void testChatBatch() {
427+
java.util.List<ChatRequest> requests = java.util.Arrays.asList(
428+
new ChatRequest().addMessage("user", "Say hi.").setInferenceCustomizer(p -> p.setNPredict(4).setSeed(1)),
429+
new ChatRequest().addMessage("user", "Say bye.").setInferenceCustomizer(p -> p.setNPredict(4).setSeed(2)));
430+
java.util.List<ChatResponse> results = model.chatBatch(requests);
431+
Assert.assertEquals(2, results.size());
432+
for (ChatResponse r : results) {
433+
Assert.assertFalse(r.getChoices().isEmpty());
434+
}
435+
}
436+
392437
@Test
393438
public void testEmbedding() {
394439
float[] embedding = model.embed(prefix);

0 commit comments

Comments
 (0)