diff --git a/.github/validate-models.bat b/.github/validate-models.bat index f5ca46fc..76fd4e5a 100644 --- a/.github/validate-models.bat +++ b/.github/validate-models.bat @@ -9,7 +9,7 @@ REM GGUF files start with magic bytes: 0x47 0x47 0x55 0x46 ("GGUF") setlocal enabledelayedexpansion -set "MODELS=models\codellama-7b.Q2_K.gguf" "models\jina-reranker-v1-tiny-en-Q4_0.gguf" "models\AMD-Llama-135m-code.Q2_K.gguf" "models\Qwen3-0.6B-Q4_K_M.gguf" +set "MODELS=models\codellama-7b.Q2_K.gguf" "models\jina-reranker-v1-tiny-en-Q4_0.gguf" "models\AMD-Llama-135m-code.Q2_K.gguf" "models\Qwen3-0.6B-Q4_K_M.gguf" "models\Qwen2.5-1.5B-Instruct-Q4_K_M.gguf" REM Vision GGUFs are validated only when present (the Windows job downloads REM them too, but the validation step must not fail when a future job opts diff --git a/.github/validate-models.sh b/.github/validate-models.sh index 6f8ef46e..3df964f4 100755 --- a/.github/validate-models.sh +++ b/.github/validate-models.sh @@ -15,6 +15,7 @@ MODELS=( "models/jina-reranker-v1-tiny-en-Q4_0.gguf" "models/AMD-Llama-135m-code.Q2_K.gguf" "models/Qwen3-0.6B-Q4_K_M.gguf" + "models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf" ) # Optional GGUFs validated only when present so jobs that do not download diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3596ec3b..a9c2b3cf 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -25,6 +25,8 @@ env: DRAFT_MODEL_NAME: "AMD-Llama-135m-code.Q2_K.gguf" REASONING_MODEL_URL: "https://huggingface.co/unsloth/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q4_K_M.gguf" REASONING_MODEL_NAME: "Qwen3-0.6B-Q4_K_M.gguf" + TOOL_MODEL_URL: "https://huggingface.co/bartowski/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf" + TOOL_MODEL_NAME: "Qwen2.5-1.5B-Instruct-Q4_K_M.gguf" NOMIC_EMBED_MODEL_URL: "https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.f16.gguf" NOMIC_EMBED_MODEL_NAME: "nomic-embed-text-v1.5.f16.gguf" # Vision model + mmproj for MultimodalIntegrationTest (issues #103 / #34). @@ -405,6 +407,8 @@ jobs: run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + - name: Download tool-calling model + run: curl -L --fail --retry 5 --retry-all-errors ${TOOL_MODEL_URL} --create-dirs -o models/${TOOL_MODEL_NAME} - name: Download nomic embedding model (issue #98 regression) run: curl -L --fail --retry 5 --retry-all-errors ${NOMIC_EMBED_MODEL_URL} --create-dirs -o models/${NOMIC_EMBED_MODEL_NAME} - name: Download vision model (issues #103 / #34) @@ -428,6 +432,7 @@ jobs: - name: Run tests run: | mvn -e --no-transfer-progress -P jcstress test \ + -Dnet.ladenthin.llama.tool.model=models/${TOOL_MODEL_NAME} \ -Dnet.ladenthin.llama.nomic.path=models/${NOMIC_EMBED_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.model=models/${VISION_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.mmproj=models/${VISION_MMPROJ_NAME} \ @@ -526,6 +531,8 @@ jobs: run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + - name: Download tool-calling model + run: curl -L --fail --retry 5 --retry-all-errors ${TOOL_MODEL_URL} --create-dirs -o models/${TOOL_MODEL_NAME} - name: Download vision model (issues #103 / #34) run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj @@ -545,6 +552,7 @@ jobs: - name: Run tests run: | mvn -e --no-transfer-progress -Dnet.ladenthin.llama.test.ngl=0 test \ + -Dnet.ladenthin.llama.tool.model=models/${TOOL_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.model=models/${VISION_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.mmproj=models/${VISION_MMPROJ_NAME} \ -Dnet.ladenthin.llama.vision.image=${VISION_IMAGE_PATH} @@ -590,6 +598,8 @@ jobs: run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + - name: Download tool-calling model + run: curl -L --fail --retry 5 --retry-all-errors ${TOOL_MODEL_URL} --create-dirs -o models/${TOOL_MODEL_NAME} - name: Download vision model (issues #103 / #34) run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj @@ -609,6 +619,7 @@ jobs: - name: Run tests run: | mvn -e --no-transfer-progress test \ + -Dnet.ladenthin.llama.tool.model=models/${TOOL_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.model=models/${VISION_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.mmproj=models/${VISION_MMPROJ_NAME} \ -Dnet.ladenthin.llama.vision.image=${VISION_IMAGE_PATH} @@ -654,6 +665,8 @@ jobs: run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + - name: Download tool-calling model + run: curl -L --fail --retry 5 --retry-all-errors ${TOOL_MODEL_URL} --create-dirs -o models/${TOOL_MODEL_NAME} - name: Download vision model (issues #103 / #34) run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj @@ -673,6 +686,7 @@ jobs: - name: Run tests run: | mvn -e --no-transfer-progress test \ + -Dnet.ladenthin.llama.tool.model=models/${TOOL_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.model=models/${VISION_MODEL_NAME} \ -Dnet.ladenthin.llama.vision.mmproj=models/${VISION_MMPROJ_NAME} \ -Dnet.ladenthin.llama.vision.image=${VISION_IMAGE_PATH} @@ -721,6 +735,8 @@ jobs: run: curl -L --fail --retry 5 --retry-all-errors $env:DRAFT_MODEL_URL --create-dirs -o models/$env:DRAFT_MODEL_NAME - name: Download reasoning model run: curl -L --fail --retry 5 --retry-all-errors $env:REASONING_MODEL_URL --create-dirs -o models/$env:REASONING_MODEL_NAME + - name: Download tool-calling model + run: curl -L --fail --retry 5 --retry-all-errors $env:TOOL_MODEL_URL --create-dirs -o models/$env:TOOL_MODEL_NAME - name: Download vision model (issues #103 / #34) run: curl -L --fail --retry 5 --retry-all-errors $env:VISION_MODEL_URL --create-dirs -o models/$env:VISION_MODEL_NAME - name: Download vision mmproj @@ -756,6 +772,7 @@ jobs: - name: Run tests run: | mvn -e --no-transfer-progress test ` + "-Dnet.ladenthin.llama.tool.model=models/$env:TOOL_MODEL_NAME" ` "-Dnet.ladenthin.llama.vision.model=models/$env:VISION_MODEL_NAME" ` "-Dnet.ladenthin.llama.vision.mmproj=models/$env:VISION_MMPROJ_NAME" ` "-Dnet.ladenthin.llama.vision.image=$env:VISION_IMAGE_PATH" diff --git a/.gitignore b/.gitignore index c7ba7df3..7b3a5e04 100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,8 @@ src/test/resources/**/*.gbnf src/main/cpp/llama.cpp/ # jcstress / jqwik test outputs (generated in repo root) -/.jqwik-database \ No newline at end of file +/.jqwik-database + +# Local AI agent tooling (not part of the project) +AGENTS.md +.agents/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d2ea0f0..0358b9bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ from version 5.0.0 onward. Pre-fork releases (`1.x`–`4.2.0`) were authored by - `CODE_OF_CONDUCT.md` (Contributor Covenant 2.0). - `docs/RELEASE.md` capturing the maintainer-facing release procedure (moved out of CHANGELOG). - OpenSSF Best Practices badge (project 12862) on README. +- OpenAI-compatible `parallel_tool_calls` support: `ChatRequest.withParallelToolCalls(Boolean)` / `getParallelToolCalls()`, `InferenceParameters.withParallelToolCalls(boolean)`, and pass-through in the `/v1/chat/completions` server mapper. +- Real-model tool-calling integration tests for blocking and streaming required tool calls (`ToolCallingIntegrationTest`, Qwen2.5-1.5B-Instruct), wired into CI and `validate-models`. ### Changed - Unified `CONTRIBUTING.md` and `SECURITY.md` structure with sibling repositories in the project family. @@ -20,6 +22,7 @@ from version 5.0.0 onward. Pre-fork releases (`1.x`–`4.2.0`) were authored by - README license badge corrected from "Apache 2.0" to "MIT" (matches `LICENSE` file and `pom.xml`). - `pom.xml` SCM URL: `tree/master` → `tree/main` (default branch renamed). - Upgraded llama.cpp from b9151 to b9172. +- Extracted the `chatWithTools` agent loop into `ToolCallingAgent`; tool-result errors (unknown tool / handler exception) are now JSON-serialized so tool names containing special characters remain valid JSON. ### Added - Reasoning-budget tests (Qwen3-0.6B). diff --git a/README.md b/README.md index 3d687079..413a127c 100644 --- a/README.md +++ b/README.md @@ -259,7 +259,8 @@ Every `net.ladenthin.llama.*` system property recognised by the library, deep-sc | `net.ladenthin.llama.lib.path` | unset (falls back to `java.library.path`) | runtime | `LlamaLoader` | Directory containing the native `jllama` shared library. Checked first, before `java.library.path`. Set with `-Dnet.ladenthin.llama.lib.path=/path/to/dir`. | | `net.ladenthin.llama.tmpdir` | unset (falls back to `java.io.tmpdir`) | runtime | `LlamaLoader` | Custom temporary directory used when extracting the native library from the JAR. | | `net.ladenthin.llama.osinfo.architecture` | unset (uses `os.arch`) | runtime | `OSInfo` | Override for the architecture string used to locate the bundled library inside the JAR. Useful when `os.arch` reports an unexpected value (e.g. inside dockcross / chrooted environments). | -| `net.ladenthin.llama.test.ngl` | `43` | test | `LlamaModelTest`, `RerankingModelTest`, `ChatScenarioTest`, `ChatAdvancedTest`, `ErrorHandlingTest`, `SessionConcurrencyTest`, `ConfigureParallelInferenceTest`, `MultimodalIntegrationTest` (via `Integer.getInteger(TestConstants.PROP_TEST_NGL, TestConstants.DEFAULT_TEST_NGL)`) | Number of GPU layers used during testing. Pin to `0` on CPU-only hosts: `mvn test -Dnet.ladenthin.llama.test.ngl=0`. | +| `net.ladenthin.llama.test.ngl` | `43` for the general suite; `0` for `ToolCallingIntegrationTest` | test | Model-backed integration tests | Number of GPU layers used during testing. Pin to `0` on CPU-only hosts: `mvn test -Dnet.ladenthin.llama.test.ngl=0`. The tool test also selects device `none` at zero layers so Metal/CUDA is not initialized. | +| `net.ladenthin.llama.tool.model` | `models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf` (test self-skips if missing) | test | `ToolCallingIntegrationTest` | Path to a tool-capable GGUF used to verify required blocking and streaming tool calls. The default matches the Qwen2.5 model in upstream llama.cpp's tool-call test matrix. | | `net.ladenthin.llama.nomic.path` | unset (test self-skips) | test | `LlamaEmbeddingsTest#testNomicEmbedLoads` | Path to a Nomic embedding model (`nomic-embed-text-v1.5.f16.gguf` or a compatible BERT-family encoder). Regression test for upstream issue #98 (BERT-encoder `result_output` assertion). | | `net.ladenthin.llama.vision.model` | unset (test self-skips) | test | `MultimodalIntegrationTest` (closes #103 / #34) | Path to a vision-capable model GGUF. Any vision-capable GGUF works; CI default is `SmolVLM-500M-Instruct-Q8_0.gguf`. | | `net.ladenthin.llama.vision.mmproj` | unset (test self-skips) | test | `MultimodalIntegrationTest` | Matching mmproj GGUF for the vision model. | @@ -368,6 +369,40 @@ try (LlamaModel model = new LlamaModel(modelParams)) { Reasoning/thinking models can receive custom Jinja template variables via `ModelParameters#setChatTemplateKwargs(Map)`. +### Tool Calling + +Use a tool-aware instruct model and enable Jinja when loading it. A typed request can either return +the model's tool calls through `chat`, or execute registered handlers until the model produces a +normal assistant response through `chatWithTools`: + +```java +ToolDefinition weather = new ToolDefinition( + "get_weather", + "Get the current weather for a city", + "{\"type\":\"object\",\"properties\":{\"city\":{\"type\":\"string\"}}," + + "\"required\":[\"city\"]}"); + +ChatRequest request = ChatRequest.empty() + .appendMessage("user", "What is the weather in Paris?") + .appendTool(weather) + .withToolChoice("auto") + .withParallelToolCalls(Boolean.FALSE); + +Map handlers = Collections.singletonMap( + "get_weather", argumentsJson -> "{\"temperature_c\":21,\"condition\":\"sunny\"}"); + +try (LlamaModel model = new LlamaModel(new ModelParameters() + .setModel("models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf") + .enableJinja())) { + ChatResponse response = model.chatWithTools(request, handlers); + System.out.println(response.getFirstContent()); +} +``` + +`tool_choice` is the OpenAI-compatible string form (`auto`, `none`, or `required`). Set +`parallel_tool_calls` to `false` when handlers should be issued one at a time. Handler failures and +unknown tool names are returned to the model as valid `{"error":"..."}` tool-result JSON. + ### Infilling You can simply set `InferenceParameters#setInputPrefix(String)` and `InferenceParameters#setInputSuffix(String)`. diff --git a/src/main/java/net/ladenthin/llama/LlamaModel.java b/src/main/java/net/ladenthin/llama/LlamaModel.java index a298ca8e..2a1e131f 100644 --- a/src/main/java/net/ladenthin/llama/LlamaModel.java +++ b/src/main/java/net/ladenthin/llama/LlamaModel.java @@ -29,7 +29,6 @@ import net.ladenthin.llama.parameters.ChatRequest; import net.ladenthin.llama.parameters.InferenceParameters; import net.ladenthin.llama.parameters.ModelParameters; -import net.ladenthin.llama.value.ChatMessage; import net.ladenthin.llama.value.ChatResponse; import net.ladenthin.llama.value.CompletionResult; import net.ladenthin.llama.value.LlamaOutput; @@ -38,7 +37,6 @@ import net.ladenthin.llama.value.Pair; import net.ladenthin.llama.value.ServerMetrics; import net.ladenthin.llama.value.StopReason; -import net.ladenthin.llama.value.ToolCall; import org.jspecify.annotations.Nullable; /** @@ -551,6 +549,10 @@ public ChatResponse chat(ChatRequest request) { if (toolChoice.isPresent()) { params = params.withToolChoice(toolChoice.get()); } + Optional parallelToolCalls = request.getParallelToolCalls(); + if (parallelToolCalls.isPresent()) { + params = params.withParallelToolCalls(parallelToolCalls.get()); + } } params = request.applyCustomizer(params); String raw = chatComplete(params); @@ -575,42 +577,7 @@ public ChatResponse chat(ChatRequest request) { * (or the last response when the round cap is hit) */ public ChatResponse chatWithTools(ChatRequest request, java.util.Map handlers) { - final int maxRounds = request.getMaxToolRounds(); - if (maxRounds < 1) { - throw new IllegalArgumentException("ChatRequest.maxToolRounds must be >= 1 (got " + maxRounds + "); " - + "chatWithTools always issues at least one chat call."); - } - ChatRequest current = request; - ChatResponse last = chat(current); - for (int round = 1; round < maxRounds; round++) { - Optional assistantOpt = last.getFirstMessage(); - // NOTE: inline !isPresent() here (not compatibilityHelper.isEmpty) so NullAway's - // CheckOptionalEmptiness recognises this as null-narrowing for the .get() below. - if (!assistantOpt.isPresent() || assistantOpt.get().getToolCalls().isEmpty()) { - return last; - } - ChatMessage assistant = assistantOpt.get(); - current = current.appendMessage(assistant); - for (ToolCall call : assistant.getToolCalls()) { - ToolHandler handler = handlers.get(call.getName()); - String result; - if (handler == null) { - result = "{\"error\":\"unknown tool: " + call.getName() + "\"}"; - } else { - try { - result = handler.invoke(call.getArgumentsJson()); - } catch (Exception e) { - result = "{\"error\":" - + net.ladenthin.llama.json.ChatResponseParser.OBJECT_MAPPER.valueToTree( - e.getClass().getSimpleName() + ": " + e.getMessage()) - + "}"; - } - } - current = current.appendMessage(ChatMessage.toolResult(call.getId(), result)); - } - last = chat(current); - } - return last; + return ToolCallingAgent.run(request, handlers, this::chat); } /** diff --git a/src/main/java/net/ladenthin/llama/ToolCallingAgent.java b/src/main/java/net/ladenthin/llama/ToolCallingAgent.java new file mode 100644 index 00000000..b69653ad --- /dev/null +++ b/src/main/java/net/ladenthin/llama/ToolCallingAgent.java @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama; + +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import net.ladenthin.llama.callback.ToolHandler; +import net.ladenthin.llama.parameters.ChatRequest; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.ToolCall; + +/** Model-independent orchestration for the tool-calling agent loop. */ +final class ToolCallingAgent { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private ToolCallingAgent() {} + + static ChatResponse run( + ChatRequest request, Map handlers, Function chatCall) { + final int maxRounds = request.getMaxToolRounds(); + if (maxRounds < 1) { + throw new IllegalArgumentException("ChatRequest.maxToolRounds must be >= 1 (got " + maxRounds + "); " + + "chatWithTools always issues at least one chat call."); + } + + ChatRequest current = request; + ChatResponse last = chatCall.apply(current); + for (int round = 1; round < maxRounds; round++) { + Optional assistantOpt = last.getFirstMessage(); + if (!assistantOpt.isPresent() || assistantOpt.get().getToolCalls().isEmpty()) { + return last; + } + + ChatMessage assistant = assistantOpt.get(); + current = current.appendMessage(assistant); + for (ToolCall call : assistant.getToolCalls()) { + current = current.appendMessage(ChatMessage.toolResult(call.getId(), invoke(call, handlers))); + } + last = chatCall.apply(current); + } + return last; + } + + private static String invoke(ToolCall call, Map handlers) { + ToolHandler handler = handlers.get(call.getName()); + if (handler == null) { + return errorJson("unknown tool: " + call.getName()); + } + try { + return handler.invoke(call.getArgumentsJson()); + } catch (Exception e) { + return errorJson(e.getClass().getSimpleName() + ": " + e.getMessage()); + } + } + + private static String errorJson(String message) { + return MAPPER.createObjectNode().put("error", message).toString(); + } +} diff --git a/src/main/java/net/ladenthin/llama/parameters/ChatRequest.java b/src/main/java/net/ladenthin/llama/parameters/ChatRequest.java index 23173bc3..828d680b 100644 --- a/src/main/java/net/ladenthin/llama/parameters/ChatRequest.java +++ b/src/main/java/net/ladenthin/llama/parameters/ChatRequest.java @@ -64,7 +64,7 @@ *

Equality

* *

{@code @EqualsAndHashCode} compares messages, tools, {@code toolChoice}, - * and {@code maxToolRounds} by value. The {@code paramsCustomizer} + * {@code parallelToolCalls}, and {@code maxToolRounds} by value. The {@code paramsCustomizer} * {@link UnaryOperator} is excluded from equality: lambdas have * compiler-synthesised identity equality which is not value-shaped, so * including it would mean two structurally-identical requests with the same @@ -88,12 +88,14 @@ public final class ChatRequest { Collections.emptyList(), Collections.emptyList(), null, + null, DEFAULT_MAX_TOOL_ROUNDS, null); private final List messages; private final List tools; private final @Nullable String toolChoice; + private final @Nullable Boolean parallelToolCalls; private final int maxToolRounds; // Lambda Consumer — toString is the implementation hash, not useful in logs; @@ -111,11 +113,13 @@ private ChatRequest( List messages, List tools, @Nullable String toolChoice, + @Nullable Boolean parallelToolCalls, int maxToolRounds, @Nullable UnaryOperator paramsCustomizer) { this.messages = messages; this.tools = tools; this.toolChoice = toolChoice; + this.parallelToolCalls = parallelToolCalls; this.maxToolRounds = maxToolRounds; this.paramsCustomizer = paramsCustomizer; } @@ -145,7 +149,13 @@ public ChatRequest appendMessage(ChatMessage message) { List next = new ArrayList(messages.size() + 1); next.addAll(messages); next.add(message); - return new ChatRequest(Collections.unmodifiableList(next), tools, toolChoice, maxToolRounds, paramsCustomizer); + return new ChatRequest( + Collections.unmodifiableList(next), + tools, + toolChoice, + parallelToolCalls, + maxToolRounds, + paramsCustomizer); } /** @@ -171,7 +181,12 @@ public ChatRequest appendTool(ToolDefinition tool) { next.addAll(tools); next.add(tool); return new ChatRequest( - messages, Collections.unmodifiableList(next), toolChoice, maxToolRounds, paramsCustomizer); + messages, + Collections.unmodifiableList(next), + toolChoice, + parallelToolCalls, + maxToolRounds, + paramsCustomizer); } // ----------------------------------------------------------------------- @@ -186,7 +201,18 @@ public ChatRequest appendTool(ToolDefinition tool) { * @return a new request with the hint replaced; this request is unchanged */ public ChatRequest withToolChoice(@Nullable String newToolChoice) { - return new ChatRequest(messages, tools, newToolChoice, maxToolRounds, paramsCustomizer); + return new ChatRequest(messages, tools, newToolChoice, parallelToolCalls, maxToolRounds, paramsCustomizer); + } + + /** + * Returns a new request with the {@code parallel_tool_calls} hint replaced. + * + * @param newParallelToolCalls whether the model may emit multiple calls in one turn, + * or {@code null} to use the model/template default + * @return a new request with the hint replaced; this request is unchanged + */ + public ChatRequest withParallelToolCalls(@Nullable Boolean newParallelToolCalls) { + return new ChatRequest(messages, tools, toolChoice, newParallelToolCalls, maxToolRounds, paramsCustomizer); } /** @@ -200,7 +226,7 @@ public ChatRequest withMaxToolRounds(int newMaxToolRounds) { if (newMaxToolRounds <= 0) { throw new IllegalArgumentException("maxToolRounds must be > 0 but was " + newMaxToolRounds); } - return new ChatRequest(messages, tools, toolChoice, newMaxToolRounds, paramsCustomizer); + return new ChatRequest(messages, tools, toolChoice, parallelToolCalls, newMaxToolRounds, paramsCustomizer); } /** @@ -210,7 +236,7 @@ public ChatRequest withMaxToolRounds(int newMaxToolRounds) { * @return a new request with the customiser replaced; this request is unchanged */ public ChatRequest withInferenceCustomizer(@Nullable UnaryOperator newCustomizer) { - return new ChatRequest(messages, tools, toolChoice, maxToolRounds, newCustomizer); + return new ChatRequest(messages, tools, toolChoice, parallelToolCalls, maxToolRounds, newCustomizer); } // ----------------------------------------------------------------------- @@ -244,6 +270,15 @@ public Optional getToolChoice() { return Optional.ofNullable(toolChoice); } + /** + * Parallel-tool-call hint accessor. + * + * @return the {@code parallel_tool_calls} hint, or {@link Optional#empty()} when unset + */ + public Optional getParallelToolCalls() { + return Optional.ofNullable(parallelToolCalls); + } + /** * Agent-loop round cap accessor. * diff --git a/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java b/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java index 824965de..6da0c49d 100644 --- a/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java @@ -98,6 +98,7 @@ public final class InferenceParameters extends JsonParameters { private static final String PARAM_CONTINUE_FINAL_MESSAGE = "continue_final_message"; private static final String PARAM_TOOLS = "tools"; private static final String PARAM_TOOL_CHOICE = "tool_choice"; + private static final String PARAM_PARALLEL_TOOL_CALLS = "parallel_tool_calls"; private static final InferenceParameters EMPTY = new InferenceParameters(); @@ -653,6 +654,16 @@ public InferenceParameters withToolChoice(@Nullable String toolChoice) { return withOptionalJson(PARAM_TOOL_CHOICE, toolChoice); } + /** + * Returns a new request with the OpenAI-compatible {@code parallel_tool_calls} flag replaced. + * + * @param parallelToolCalls whether the model may emit more than one tool call in a turn + * @return a new instance; this instance is unchanged + */ + public InferenceParameters withParallelToolCalls(boolean parallelToolCalls) { + return withScalar(PARAM_PARALLEL_TOOL_CALLS, parallelToolCalls); + } + /** * Returns a new request with the top-n-sigma threshold replaced (default: -1.0, disabled). * diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java b/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java index aa01487a..e4424c91 100644 --- a/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java +++ b/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java @@ -12,7 +12,8 @@ /** * Pure mapping from an OpenAI {@code /v1/chat/completions} request body to {@link InferenceParameters}. * - *

The structural fields — {@code messages}, {@code tools}, {@code tool_choice} — are forwarded + *

The structural fields — {@code messages}, {@code tools}, {@code tool_choice}, and + * {@code parallel_tool_calls} — are forwarded * verbatim as raw JSON so the full OpenAI shape (assistant {@code tool_calls}, * {@code role:"tool"} results with {@code tool_call_id}, and vision {@code image_url} content parts) * round-trips untouched into the native chat-template parser. Sampling fields are translated to the @@ -49,6 +50,10 @@ InferenceParameters toInferenceParameters(JsonNode request) { if (toolChoice.isTextual()) { params = params.withToolChoice(toolChoice.asText()); } + JsonNode parallelToolCalls = request.path("parallel_tool_calls"); + if (parallelToolCalls.isBoolean()) { + params = params.withParallelToolCalls(parallelToolCalls.asBoolean()); + } } JsonNode temperature = request.path("temperature"); diff --git a/src/test/java/net/ladenthin/llama/TestConstants.java b/src/test/java/net/ladenthin/llama/TestConstants.java index 566c6d56..a4976a6a 100644 --- a/src/test/java/net/ladenthin/llama/TestConstants.java +++ b/src/test/java/net/ladenthin/llama/TestConstants.java @@ -23,6 +23,12 @@ public class TestConstants { /** Path to the Qwen3 thinking model used for reasoning budget tests. */ public static final String REASONING_MODEL_PATH = "models/Qwen3-0.6B-Q4_K_M.gguf"; + /** System property overriding the GGUF used by the real tool-calling integration tests. */ + public static final String PROP_TOOL_MODEL_PATH = LlamaSystemProperties.PREFIX + ".tool.model"; + + /** Qwen2.5 tool-capable model used by upstream llama.cpp's blocking and streaming tests. */ + public static final String DEFAULT_TOOL_MODEL_PATH = "models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"; + /** * System property holding a path to a Nomic embedding model * ({@code nomic-embed-text-v1.5.f16.gguf} or a compatible BERT-family encoder). diff --git a/src/test/java/net/ladenthin/llama/ToolCallingAgentTest.java b/src/test/java/net/ladenthin/llama/ToolCallingAgentTest.java new file mode 100644 index 00000000..ff000b8d --- /dev/null +++ b/src/test/java/net/ladenthin/llama/ToolCallingAgentTest.java @@ -0,0 +1,174 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import net.ladenthin.llama.callback.ToolHandler; +import net.ladenthin.llama.parameters.ChatRequest; +import net.ladenthin.llama.value.ChatChoice; +import net.ladenthin.llama.value.ChatMessage; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.Timings; +import net.ladenthin.llama.value.ToolCall; +import net.ladenthin.llama.value.Usage; +import org.junit.jupiter.api.Test; + +class ToolCallingAgentTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + void invokesHandlerAndAppendsAssistantAndToolResultTurns() { + ToolCall call = new ToolCall("call-1", "weather", "{\"city\":\"Paris\"}"); + List requests = new ArrayList(); + List responses = + new ArrayList(Arrays.asList(toolResponse(call), textResponse("It is sunny."))); + AtomicInteger invocations = new AtomicInteger(); + Map handlers = Collections.singletonMap("weather", args -> { + invocations.incrementAndGet(); + assertThat(args, is("{\"city\":\"Paris\"}")); + return "{\"condition\":\"sunny\"}"; + }); + + ChatResponse result = ToolCallingAgent.run( + ChatRequest.empty().appendMessage("user", "Weather?").withMaxToolRounds(2), handlers, request -> { + requests.add(request); + return responses.remove(0); + }); + + assertThat(result.getFirstContent(), is("It is sunny.")); + assertThat(invocations.get(), is(1)); + assertThat(requests, hasSize(2)); + List secondRound = requests.get(1).getMessages(); + assertThat(secondRound, hasSize(3)); + assertThat(secondRound.get(1).getToolCalls().get(0), is(call)); + assertThat(secondRound.get(2).getRole(), is("tool")); + assertThat(secondRound.get(2).getToolCallId().orElseThrow(), is("call-1")); + assertThat(secondRound.get(2).getContent(), is("{\"condition\":\"sunny\"}")); + } + + @Test + void invokesEveryToolCallInOneAssistantTurn() { + ToolCall first = new ToolCall("call-1", "first", "{}"); + ToolCall second = new ToolCall("call-2", "second", "{}"); + Map handlers = new HashMap(); + handlers.put("first", args -> "1"); + handlers.put("second", args -> "2"); + List requests = new ArrayList(); + List responses = new ArrayList( + Arrays.asList(toolResponse(Arrays.asList(first, second)), textResponse("done"))); + + ToolCallingAgent.run( + ChatRequest.empty().appendMessage("user", "both").withMaxToolRounds(2), handlers, request -> { + requests.add(request); + return responses.remove(0); + }); + + List secondRound = requests.get(1).getMessages(); + assertThat(secondRound, hasSize(4)); + assertThat(secondRound.get(2).getContent(), is("1")); + assertThat(secondRound.get(3).getContent(), is("2")); + } + + @Test + void unknownToolNameIsReturnedAsValidJson() throws IOException { + String result = captureToolResult( + new ToolCall("call-1", "bad\"name", "{}"), Collections.emptyMap()); + JsonNode parsed = MAPPER.readTree(result); + assertThat(parsed.path("error").asText(), is("unknown tool: bad\"name")); + } + + @Test + void handlerExceptionIsReturnedAsValidJson() throws IOException { + Map handlers = Collections.singletonMap("broken", args -> { + throw new IllegalStateException("bad \"value\""); + }); + String result = captureToolResult(new ToolCall("call-1", "broken", "{}"), handlers); + JsonNode parsed = MAPPER.readTree(result); + assertThat(parsed.path("error").asText(), is("IllegalStateException: bad \"value\"")); + } + + @Test + void roundCapStopsBeforeExecutingLastResponseCalls() { + AtomicInteger chatCalls = new AtomicInteger(); + AtomicInteger toolCalls = new AtomicInteger(); + ToolCall call = new ToolCall("call-1", "echo", "{}"); + + ChatResponse result = ToolCallingAgent.run( + ChatRequest.empty().appendMessage("user", "echo").withMaxToolRounds(1), + Collections.singletonMap("echo", args -> { + toolCalls.incrementAndGet(); + return args; + }), + request -> { + chatCalls.incrementAndGet(); + return toolResponse(call); + }); + + assertThat(result.getFirstMessage().orElseThrow().getToolCalls(), hasSize(1)); + assertThat(chatCalls.get(), is(1)); + assertThat(toolCalls.get(), is(0)); + } + + @Test + void responseWithoutToolCallsStopsImmediately() { + AtomicInteger chatCalls = new AtomicInteger(); + ChatResponse result = ToolCallingAgent.run( + ChatRequest.empty().appendMessage("user", "hi"), + Collections.emptyMap(), + request -> { + chatCalls.incrementAndGet(); + return textResponse("hello"); + }); + assertThat(result.getFirstContent(), is("hello")); + assertThat(chatCalls.get(), is(1)); + } + + private String captureToolResult(ToolCall call, Map handlers) { + List requests = new ArrayList(); + List responses = + new ArrayList(Arrays.asList(toolResponse(call), textResponse("done"))); + ToolCallingAgent.run( + ChatRequest.empty().appendMessage("user", "go").withMaxToolRounds(2), handlers, request -> { + requests.add(request); + return responses.remove(0); + }); + return requests.get(1).getMessages().get(2).getContent(); + } + + private static ChatResponse toolResponse(ToolCall call) { + return toolResponse(Collections.singletonList(call)); + } + + private static ChatResponse toolResponse(List calls) { + return response(ChatMessage.assistantToolCalls("", calls), "tool_calls"); + } + + private static ChatResponse textResponse(String text) { + return response(new ChatMessage("assistant", text), "stop"); + } + + private static ChatResponse response(ChatMessage message, String finishReason) { + return new ChatResponse( + "id", + Collections.singletonList(new ChatChoice(0, message, finishReason)), + new Usage(0, 0), + new Timings(0, 0, 0.0, 0.0, 0, 0.0, 0.0, 0, 0), + "{}"); + } +} diff --git a/src/test/java/net/ladenthin/llama/ToolCallingIntegrationTest.java b/src/test/java/net/ladenthin/llama/ToolCallingIntegrationTest.java new file mode 100644 index 00000000..83c3a6d2 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/ToolCallingIntegrationTest.java @@ -0,0 +1,124 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import net.ladenthin.llama.parameters.ChatRequest; +import net.ladenthin.llama.parameters.InferenceParameters; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.value.ChatResponse; +import net.ladenthin.llama.value.ToolCall; +import net.ladenthin.llama.value.ToolDefinition; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** Real-model coverage for the blocking and streaming tool-call paths. */ +public class ToolCallingIntegrationTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final ToolDefinition TEST_TOOL = new ToolDefinition( + "test", + "", + "{\"type\":\"object\",\"properties\":{\"success\":{\"type\":\"boolean\",\"const\":true}}," + + "\"required\":[\"success\"]}"); + + private static LlamaModel model; + + @BeforeAll + public static void loadModel() { + String modelPath = + System.getProperty(TestConstants.PROP_TOOL_MODEL_PATH, TestConstants.DEFAULT_TOOL_MODEL_PATH); + Assumptions.assumeTrue(new File(modelPath).exists(), "Tool-calling model missing: " + modelPath); + int gpuLayers = Integer.getInteger(TestConstants.PROP_TEST_NGL, 0); + ModelParameters parameters = new ModelParameters() + .setModel(modelPath) + .setCtxSize(8192) + .setGpuLayers(gpuLayers) + .setFit(false) + .enableJinja(); + if (gpuLayers == 0) { + parameters.setDevices("none"); + } + model = new LlamaModel(parameters); + } + + @AfterAll + public static void closeModel() { + if (model != null) { + model.close(); + } + } + + @Test + public void requiredToolCallIsParsedFromBlockingResponse() throws IOException { + ChatResponse response = model.chat(toolRequest()); + + List calls = response.getFirstMessage().orElseThrow().getToolCalls(); + assertThat(calls, hasSize(1)); + assertThat(calls.get(0).getName(), is("test")); + assertThat( + MAPPER.readTree(calls.get(0).getArgumentsJson()).path("success").asBoolean(), is(true)); + } + + @Test + public void requiredToolCallIsParsedFromStreamingResponse() throws IOException { + ChatRequest request = toolRequest(); + InferenceParameters params = request.applyCustomizer(InferenceParameters.empty() + .withMessagesJson(request.buildMessagesJson()) + .withToolsJson(request.buildToolsJson().orElseThrow()) + .withToolChoice(request.getToolChoice().orElseThrow()) + .withParallelToolCalls(request.getParallelToolCalls().orElseThrow()) + .withUseChatTemplate(true)); + List chunks = new ArrayList(); + + model.streamChatCompletion(params, chunks::add); + + StringBuilder name = new StringBuilder(); + StringBuilder arguments = new StringBuilder(); + for (String chunk : chunks) { + JsonNode toolCalls = + MAPPER.readTree(chunk).path("choices").path(0).path("delta").path("tool_calls"); + if (!toolCalls.isArray()) { + continue; + } + for (JsonNode call : toolCalls) { + JsonNode function = call.path("function"); + if (function.path("name").isTextual()) { + name.append(function.path("name").asText()); + } + if (function.path("arguments").isTextual()) { + arguments.append(function.path("arguments").asText()); + } + } + } + + assertThat(name.toString(), is("test")); + assertThat(MAPPER.readTree(arguments.toString()).path("success").asBoolean(), is(true)); + } + + private static ChatRequest toolRequest() { + return ChatRequest.empty() + .appendMessage("system", "You are a coding assistant.") + .appendMessage("user", "Write an example") + .appendTool(TEST_TOOL) + .withToolChoice("required") + .withParallelToolCalls(Boolean.FALSE) + .withInferenceCustomizer(params -> params.withNPredict(512) + .withTemperature(0.0f) + .withTopK(1) + .withTopP(1.0f)); + } +} diff --git a/src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java b/src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java index 388d0502..16d1c67f 100644 --- a/src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/ChatRequestTest.java @@ -57,6 +57,15 @@ void withToolChoiceReturnsNewInstance() { assertThat(derived.getToolChoice().orElseThrow(), is("auto")); } + @Test + void withParallelToolCallsReturnsNewInstance() { + ChatRequest original = ChatRequest.empty(); + ChatRequest derived = original.withParallelToolCalls(Boolean.FALSE); + assertThat(derived, is(not(sameInstance(original)))); + assertThat("original hint unset", original.getParallelToolCalls().isPresent(), is(false)); + assertThat(derived.getParallelToolCalls().orElseThrow(), is(false)); + } + @Test void withMaxToolRoundsReturnsNewInstance() { ChatRequest original = ChatRequest.empty(); @@ -137,6 +146,13 @@ void differentMaxToolRoundsNotEqual() { assertThat(a, is(not(b))); } + @Test + void differentParallelToolCallsNotEqual() { + ChatRequest a = ChatRequest.empty().withParallelToolCalls(Boolean.TRUE); + ChatRequest b = ChatRequest.empty().withParallelToolCalls(Boolean.FALSE); + assertThat(a, is(not(b))); + } + @Test @DisplayName( "the customiser is excluded from equality — two requests with the same content but different lambdas are equal") diff --git a/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java b/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java index 33b7f494..df0d4cef 100644 --- a/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java @@ -73,6 +73,12 @@ public void testSetNPredict() { assertThat(params.parameters.get("n_predict"), is("42")); } + @Test + public void testSetParallelToolCalls() { + InferenceParameters params = new InferenceParameters("").withParallelToolCalls(false); + assertThat(params.parameters.get("parallel_tool_calls"), is("false")); + } + @Test public void testSetTemperature() { InferenceParameters params = new InferenceParameters("").withTemperature(0.5f); diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java index 9813b5d9..e9a3e27f 100644 --- a/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java +++ b/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java @@ -98,6 +98,14 @@ public void toolsEnableChatTemplateAndForwardChoice() throws IOException { assertThat(out.path("use_jinja").asBoolean(), is(true)); } + @Test + public void parallelToolCallsForwarded() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]," + + "\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"a\"}}]," + + "\"parallel_tool_calls\":false}"); + assertThat(out.path("parallel_tool_calls").asBoolean(), is(false)); + } + @Test public void stopAsSingleStringMapped() throws IOException { JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}],\"stop\":\"END\"}");