diff --git a/.clang-format b/.clang-format index 7daa752e..ae609f09 100644 --- a/.clang-format +++ b/.clang-format @@ -175,7 +175,10 @@ RequiresClausePosition: OwnLine RequiresExpressionIndentation: OuterScope SeparateDefinitionBlocks: Leave ShortNamespaceLines: 1 -SortIncludes: CaseSensitive +# Never reorder #include lines: this project has order-sensitive includes — the upstream +# server-*.h headers must precede json_helpers.hpp / jni_helpers.hpp (which use the `json` +# alias those headers define). Alphabetical sorting breaks the build (json undefined). +SortIncludes: Never SortJavaStaticImport: Before SortUsingDeclarations: LexicographicNumeric SpaceAfterCStyleCast: false diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 00000000..ee81fbce --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: 2026 Bernard Ladenthin +# +# SPDX-License-Identifier: MIT + +name: clang-format +on: + push: + pull_request: + workflow_dispatch: + +# Enforces a single, pinned clang-format across all C++ sources so formatting is +# reproducible between contributors and CI. Bump CLANG_FORMAT_VERSION here and in +# CLAUDE.md (Code Formatting) together, then reformat the tree with the same version. +env: + CLANG_FORMAT_VERSION: "22.1.5" + +jobs: + clang-format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - name: Install pinned clang-format + run: pip install "clang-format==${CLANG_FORMAT_VERSION}" + - name: Check C++ formatting + run: | + clang-format --version + # All hand-written C++ sources; the generated JNI header (src/main/cpp/jllama.h, + # produced by `javac -h`) is intentionally excluded. + files=$(find src/main/cpp src/test/cpp -type f \( -name '*.cpp' -o -name '*.hpp' \) | sort) + echo "Checking:"; echo "$files" + clang-format --dry-run --Werror $files diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d6bccb4e..3596ec3b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -398,19 +398,19 @@ jobs: name: Linux-x86_64-libraries path: ${{ github.workspace }}/src/main/resources/net/ladenthin/llama/ - name: Download text generation model - run: curl -L --fail ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Download reranking model - run: curl -L --fail ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - name: Download draft model - run: curl -L --fail ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model - run: curl -L --fail ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} - name: Download nomic embedding model (issue #98 regression) - run: curl -L --fail ${NOMIC_EMBED_MODEL_URL} --create-dirs -o models/${NOMIC_EMBED_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${NOMIC_EMBED_MODEL_URL} --create-dirs -o models/${NOMIC_EMBED_MODEL_NAME} - name: Download vision model (issues #103 / #34) - run: curl -L --fail ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj - run: curl -L --fail ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} - name: List files in models directory run: ls -l models/ - name: Validate model files @@ -519,17 +519,17 @@ jobs: name: macos-14-libraries path: ${{ github.workspace }}/src/main/resources/net/ladenthin/llama/ - name: Download text generation model - run: curl -L --fail ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Download reranking model - run: curl -L --fail ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - name: Download draft model - run: curl -L --fail ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model - run: curl -L --fail ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} - name: Download vision model (issues #103 / #34) - run: curl -L --fail ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj - run: curl -L --fail ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} - name: List files in models directory run: ls -l models/ - name: Validate model files @@ -583,17 +583,17 @@ jobs: name: macos-15-libraries path: ${{ github.workspace }}/src/main/resources/net/ladenthin/llama/ - name: Download text generation model - run: curl -L --fail ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Download reranking model - run: curl -L --fail ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - name: Download draft model - run: curl -L --fail ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model - run: curl -L --fail ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} - name: Download vision model (issues #103 / #34) - run: curl -L --fail ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj - run: curl -L --fail ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} - name: List files in models directory run: ls -l models/ - name: Validate model files @@ -647,17 +647,17 @@ jobs: name: macos-15-metal-libraries path: ${{ github.workspace }}/src/main/resources/net/ladenthin/llama/ - name: Download text generation model - run: curl -L --fail ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Download reranking model - run: curl -L --fail ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - name: Download draft model - run: curl -L --fail ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${DRAFT_MODEL_URL} --create-dirs -o models/${DRAFT_MODEL_NAME} - name: Download reasoning model - run: curl -L --fail ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME} - name: Download vision model (issues #103 / #34) - run: curl -L --fail ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MODEL_URL} --create-dirs -o models/${VISION_MODEL_NAME} - name: Download vision mmproj - run: curl -L --fail ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} + run: curl -L --fail --retry 5 --retry-all-errors ${VISION_MMPROJ_URL} --create-dirs -o models/${VISION_MMPROJ_NAME} - name: List files in models directory run: ls -l models/ - name: Validate model files @@ -714,17 +714,17 @@ jobs: name: Windows-x86_64-libraries path: ${{ github.workspace }}/src/main/resources/net/ladenthin/llama/ - name: Download text generation model - run: curl -L --fail $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME + run: curl -L --fail --retry 5 --retry-all-errors $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Download reranking model - run: curl -L --fail $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME + run: curl -L --fail --retry 5 --retry-all-errors $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME - name: Download draft model - run: curl -L --fail $env:DRAFT_MODEL_URL --create-dirs -o models/$env:DRAFT_MODEL_NAME + run: curl -L --fail --retry 5 --retry-all-errors $env:DRAFT_MODEL_URL --create-dirs -o models/$env:DRAFT_MODEL_NAME - name: Download reasoning model - run: curl -L --fail $env:REASONING_MODEL_URL --create-dirs -o models/$env:REASONING_MODEL_NAME + run: curl -L --fail --retry 5 --retry-all-errors $env:REASONING_MODEL_URL --create-dirs -o models/$env:REASONING_MODEL_NAME - name: Download vision model (issues #103 / #34) - run: curl -L --fail $env:VISION_MODEL_URL --create-dirs -o models/$env:VISION_MODEL_NAME + run: curl -L --fail --retry 5 --retry-all-errors $env:VISION_MODEL_URL --create-dirs -o models/$env:VISION_MODEL_NAME - name: Download vision mmproj - run: curl -L --fail $env:VISION_MMPROJ_URL --create-dirs -o models/$env:VISION_MMPROJ_NAME + run: curl -L --fail --retry 5 --retry-all-errors $env:VISION_MMPROJ_URL --create-dirs -o models/$env:VISION_MMPROJ_NAME - name: List files in models directory run: ls -l models/ - name: Validate model files diff --git a/CLAUDE.md b/CLAUDE.md index 84edd5a8..85b4a1fb 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -393,10 +393,28 @@ not track the loader's own Java package). This is the same `spotbugs-exclude.xml`, PIT `targetClasses`, and `CMakeLists.txt` OSInfo repairs. ### Code Formatting + +C++ formatting is **enforced in CI** (`.github/workflows/clang-format.yml`) with a **pinned** +clang-format — currently **22.1.5**, installed via `pip install clang-format==22.1.5`. Format with +that exact version before committing; a different clang-format version reflows code differently and +will fail the check. + ```bash -clang-format -i src/main/cpp/*.cpp src/main/cpp/*.hpp # Format C++ code +pip install "clang-format==22.1.5" +clang-format -i src/main/cpp/*.cpp src/main/cpp/*.hpp src/test/cpp/*.cpp # Format C++ code ``` +The generated JNI header `src/main/cpp/jllama.h` (produced by `javac -h`) is intentionally excluded. +To bump the enforced version, update the pin in **both** the workflow (`CLANG_FORMAT_VERSION`) and +this line, then reformat the whole tree with the new version in the same commit. + +**`.clang-format` sets `SortIncludes: Never` — do not re-enable include sorting.** The project has +order-sensitive includes (see the "Include order rule" above): the upstream `server-*.h` headers and +`utils.hpp` must precede `json_helpers.hpp` / `jni_helpers.hpp`, which use the `json` alias those +headers define. Alphabetical sorting moves the helper headers first and breaks the build with +`'json' does not name a type` (it slips past a local build whose toolchain resolves `json` anyway, +but fails the manylinux/aarch64/Android CI compilers). Keep include order manual. + ### Javadoc — must build cleanly before `mvn package` The release packaging job runs `mvn package` with the `release` profile, which attaches @@ -453,7 +471,9 @@ If the local check passes (`BUILD SUCCESS`), the `mvn package` job in - `LlamaIterator` / `LlamaIterable` — Streaming generation via Java `Iterator`/`Iterable`. - `LlamaLoader` — Extracts the platform-specific native library from the JAR to a temp directory, or finds it on `java.library.path`. - `OSInfo` — Detects OS and architecture for library resolution. -- `server.LlamaServer` — Optional OpenAI-compatible HTTP server and the fat-jar `Main-Class`. `LlamaServerArgs` parses the CLI; `OaiRouter` / `OaiHttpServer` (NanoHTTPD) map `POST /v1/chat/completions`, `/v1/completions`, `/v1/embeddings` and `GET /v1/models` to the `LlamaModel.handle*` methods. NanoHTTPD is an `` dependency (bundled only in the fat jar, not inherited by library consumers). The `server` package is a dedicated top layer in the ArchUnit `layeredArchitecture` rule (the only layer allowed to access the root `Api`). See README "OpenAI-compatible HTTP server". +- **`server` package — OpenAI-compatible HTTP endpoint. NOTE: two implementations coexist on this branch pending a "best of both" consolidation (see [`TODO.md`](TODO.md)).** + - `server.OpenAiCompatServer` — built on the JDK's `com.sun.net.httpserver` (no new dependency). Serves `POST /v1/chat/completions` (streaming via SSE + non-streaming) and `GET /v1/models` by delegating to `LlamaModel.chatComplete` / `LlamaModel.streamChatCompletion`, so editors that speak the OpenAI protocol (e.g. VS Code Copilot "Custom Endpoint") can drive a local model. Streaming uses the native OAI chunk path (`requestChatCompletionStream` / `receiveChatCompletionChunk`), preserving `delta.tool_calls`. + - `server.LlamaServer` — an OpenAI-compatible HTTP server and the fat-jar `Main-Class`. `LlamaServerArgs` parses the CLI; `OaiRouter` / `OaiHttpServer` (NanoHTTPD) map `POST /v1/chat/completions`, `/v1/completions`, `/v1/embeddings` and `GET /v1/models` to the `LlamaModel.handle*` methods. NanoHTTPD is an `` dependency (bundled only in the fat jar, not inherited by library consumers). The `server` package is a dedicated top layer in the ArchUnit `layeredArchitecture` rule (the only layer allowed to access the root `Api`). See README "OpenAI-compatible HTTP server". **Native layer** (`src/main/cpp/`): - `jllama.cpp` — JNI implementation bridging Java calls to llama.cpp. ~1,215 lines; 17 native methods. @@ -478,7 +498,7 @@ The project C++ helpers follow a strict semantic split: Functions: `get_result_error_message`, `results_to_json`, `rerank_results_to_json`, `parse_encoding_format`, `extract_embedding_prompt`, `is_infill_request`, -`parse_slot_prompt_similarity`, `parse_positive_int_config`. +`parse_slot_prompt_similarity`, `parse_positive_int_config`, `wrap_stream_chunk`. **`log_helpers.hpp`** — Pure log-formatting transforms. - Input: `ggml_log_level`, message text (`const char*`), an explicit `std::time_t` timestamp. @@ -584,11 +604,11 @@ ctest --test-dir build --output-on-failure -R "ResultsToJson" |------|-------|-------| | `src/test/cpp/test_utils.cpp` | 156 | Upstream helpers: `server_tokens`, `server_grammar_trigger`, `gen_tool_call_id`, `json_value`, `json_get_nested_values`, UTF-8 helpers, `format_response_rerank`, `format_embeddings_response_oaicompat`, `oaicompat_completion_params_parse`, `oaicompat_chat_params_parse`, `are_lora_equal`, `strip_flag_from_argv`, `token_piece_value`, `json_is_array_and_contains_numbers`, `format_oai_sse`, `format_oai_resp_sse`, `format_anthropic_sse` | | `src/test/cpp/test_server.cpp` | 188 | Upstream result types: `result_timings`, `task_params::to_json()` (incl. `dry_sequence_breakers`, `preserved_tokens`, `timings_per_token`), `completion_token_output`, `server_task_result_cmpl_partial` (non-oaicompat + `to_json_oaicompat` + logprobs + `to_json_oaicompat_chat` + `to_json_anthropic` + dispatcher), `server_task_result_cmpl_final` (non-oaicompat + `to_json_oaicompat` + `to_json_oaicompat_chat` + `to_json_oaicompat_chat_stream` + `to_json_anthropic` + `to_json_anthropic_stream` + tool_calls + dispatcher), `server_task_result_embd`, `server_task_result_rerank`, `server_task_result_metrics`, `server_task_result_slot_save_load`, `server_task_result_slot_erase`, `server_task_result_apply_lora`, `server_task_result_error`, `format_error_response`, `server_task::need_sampling()`, `server_task::n_tokens()`, `server_task::params_from_json_cmpl()` (parsing pipeline + grammar routing + error paths), `response_fields` projection | -| `src/test/cpp/test_json_helpers.cpp` | 42 | All functions in `json_helpers.hpp`: `get_result_error_message`, `results_to_json`, `rerank_results_to_json`, `parse_encoding_format`, `extract_embedding_prompt`, `is_infill_request`, `parse_slot_prompt_similarity`, `parse_positive_int_config` | +| `src/test/cpp/test_json_helpers.cpp` | 47 | All functions in `json_helpers.hpp`: `get_result_error_message`, `results_to_json`, `rerank_results_to_json`, `parse_encoding_format`, `extract_embedding_prompt`, `is_infill_request`, `parse_slot_prompt_similarity`, `parse_positive_int_config`, `wrap_stream_chunk` | | `src/test/cpp/test_log_helpers.cpp` | 13 | All functions in `log_helpers.hpp`: `log_level_name`, `format_log_as_json` | | `src/test/cpp/test_jni_helpers.cpp` | 41 | All functions in `jni_helpers.hpp` using a zero-filled `JNINativeInterface_` mock | -**Current total: 440 tests (all passing).** +**Current total: 445 tests (all passing).** #### Upstream source location (in CMake build tree) diff --git a/CMakeLists.txt b/CMakeLists.txt index 28186d3c..d3c688e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -355,5 +355,14 @@ if(BUILD_TESTING) SERVER_VERBOSE=$ ) - gtest_discover_tests(jllama_test) + # gtest_discover_tests runs the freshly built jllama_test executable at build + # time (POST_BUILD) to enumerate test cases. The default discovery timeout is + # 5s. The 32-bit Windows (Win32) build links the entire llama/ggml/server tree + # statically into one large binary whose startup + test enumeration sits right + # at that 5s boundary on shared CI runners: the same b9682 binary discovered + # within 5s in one run but was killed at the 5s timeout in another (empty + # output, process still alive — a timeout, not a crash). x64/Linux/macOS finish + # well under the default. Raise the budget so 32-bit discovery is not flaky; + # this is a maximum, so fast platforms still return immediately. + gtest_discover_tests(jllama_test DISCOVERY_TIMEOUT 120) endif() diff --git a/README.md b/README.md index ec7bce30..3d687079 100644 --- a/README.md +++ b/README.md @@ -399,6 +399,83 @@ Server state is exposed via `getMetrics()`, `eraseSlot(int)`, `saveSlot(int, Str ### OpenAI-compatible HTTP server +> **Note — two implementations pending consolidation.** This branch currently ships **two** +> independent OpenAI-compatible servers in `net.ladenthin.llama.server`, awaiting a "best of both" +> merge (see [TODO.md](TODO.md)). Both are documented below; they will be unified into one. + +#### Option A — `OpenAiCompatServer` (dependency-free, streaming SSE) — for VS Code Copilot and other OpenAI clients + +`net.ladenthin.llama.server.OpenAiCompatServer` turns a loaded model into a local +OpenAI-compatible HTTP endpoint using only the JDK's built-in `com.sun.net.httpserver` — no extra +dependency and no separate server process. It serves: + +- `POST /v1/chat/completions` — streaming (Server-Sent Events) and non-streaming, forwarding + `messages`/`tools` verbatim. The streaming path carries `delta.tool_calls`, so agent/tool-calling + clients work. +- `GET /v1/models` — advertises the configured model id. + +Embed it in your app: + +```java +ModelParameters modelParams = new ModelParameters().setModel("models/model.gguf").setParallel(2); +OpenAiServerConfig config = OpenAiServerConfig.builder().port(8080).modelId("local-model").build(); +try (LlamaModel model = new LlamaModel(modelParams); + OpenAiCompatServer server = new OpenAiCompatServer(model, config).start()) { + Thread.currentThread().join(); // serve until interrupted +} +``` + +…or run it standalone: + +```bash +java -cp target/llama-.jar net.ladenthin.llama.server.OpenAiCompatServer \ + --model models/model.gguf --port 8080 --model-id local-model +``` + +Verify with curl: + +```bash +curl -N http://127.0.0.1:8080/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model":"local-model","stream":true,"messages":[{"role":"user","content":"hi"}]}' +``` + +**VS Code Copilot setup:** Command Palette → **Chat: Manage Language Models** → **Add Models** → +**Custom Endpoint**; enter a group name, a display name and any non-empty API key, and pick API type +**Chat Completions**. VS Code then opens `chatLanguageModels.json` — set the model `url` to your +endpoint (the host/port go here, not in the form): + +```json +[ + { + "name": "Local llama.cpp", + "vendor": "customendpoint", + "apiKey": "local-dummy-key", + "apiType": "chat-completions", + "models": [ + { + "id": "local-model", + "name": "Local model", + "url": "http://127.0.0.1:8080/v1/chat/completions", + "toolCalling": true, + "vision": false, + "maxInputTokens": 6144, + "maxOutputTokens": 2048 + } + ] + } +] +``` + +Notes: BYOK powers the chat/agent experience only (inline completions and embeddings still require a +GitHub account). On CPU, prefer a smaller model and a modest context window — the server emits SSE +heartbeats so a long prompt prefill does not trip the client's stream-inactivity timeout. Agent-mode +tool calling depends on the model's own tool-calling quality. Pass `--api-key` (or +`OpenAiServerConfig.apiKey(...)`) to require an `Authorization: Bearer` token; the server binds to +`127.0.0.1` by default. + +#### Option B — `LlamaServer` (NanoHTTPD, fat-jar `Main-Class`) + The fat jar built by the `assembly` profile (`mvn -P assembly package`) is runnable: its `Main-Class` is `net.ladenthin.llama.server.LlamaServer`, a small [NanoHTTPD](https://github.com/NanoHttpd/nanohttpd) server that loads a GGUF model in-process and serves OpenAI-compatible endpoints by forwarding each diff --git a/TODO.md b/TODO.md index 635a2d41..f195daa1 100644 --- a/TODO.md +++ b/TODO.md @@ -13,6 +13,56 @@ cross-cutting initiative. ## Open — jllama-specific +### ⚠️ OpenAI server: TWO implementations to consolidate ("best of both") + +Two independent, Claude-generated OpenAI-compatible servers now coexist in +`net.ladenthin.llama.server` after PR #240 was merged on top of the NanoHTTPD server that landed +via #242. **This is a temporary state**; one unified implementation must be chosen. Until then both +compile and are tested side by side. + +| | **Option A — `OpenAiCompatServer`** (from PR #240) | **Option B — `LlamaServer`** (from #242) | +|---|---|---| +| HTTP layer | JDK `com.sun.net.httpserver` (the supported `jdk.httpserver` module — **no dependency**) | NanoHTTPD (`` dep, bundled only in fat jar) | +| Streaming | **Yes** — SSE with `delta.tool_calls`, heartbeats during prefill | No — blocking, full JSON per request | +| Routes | `POST /v1/chat/completions`, `GET /v1/models` | `POST /v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `GET /v1/models`, `GET /health` | +| Entry point | CLI launcher + embeddable; `OpenAiServerConfig` builder; optional bearer auth; binds `127.0.0.1` | fat-jar `Main-Class`; `LlamaServerArgs` CLI (`--host/--port/--ctx-size/--threads/…`) | +| Native path | `requestChatCompletionStream` / `receiveChatCompletionChunk` (+ `wrap_stream_chunk` C++ helper) | `LlamaModel.handle*` (blocking) | +| Tests | mapper/SSE/parser unit tests + model-free HTTP test over a socket (`ChatBackend` seam) | `OaiRouterTest`, `LlamaServerArgsTest`, `OaiHttpServerIntegrationTest` | + +**Important cross-insight:** Option B's own follow-up TODO below ("OpenAI-compatible server: token +streaming (SSE) + Java-8 HTTP layer") lists SSE as *the main functional gap* and says to **avoid** +`com.sun.net.httpserver` because it is "ArchUnit-banned". Option A **already implements that SSE +streaming** with `com.sun.net.httpserver`, and the ban was lifted correctly: `com.sun.net.httpserver` +is a *supported, exported* JDK API (the `jdk.httpserver` module), not an internal `com.sun..` package — +the `noInternalJdkImports` ArchUnit rule now carries an explicit exception for it. So the premise that +blocked the JDK approach on Option B's side does not hold. + +**Consolidation task (separate session — a kickoff prompt accompanies this change):** go through both +implementations, take the best of each, settle on ONE server, delete the other, reconcile the +dependency (`pom.xml` NanoHTTPD + assembly), the ArchUnit `layeredArchitecture` `Server` layer, the +`spotbugs-exclude.xml` entries, `package-info.java`, the README "OpenAI-compatible HTTP server" +section, and this TODO (including the now-partly-moot SSE section below). + +### OpenAI-compatible HTTP endpoint (shipped; follow-ups open) + +`net.ladenthin.llama.server.OpenAiCompatServer` exposes `POST /v1/chat/completions` (streaming via +SSE + non-streaming) and `GET /v1/models` over the JDK's built-in `com.sun.net.httpserver` (no new +dependency), so editors that speak the OpenAI protocol (e.g. VS Code Copilot "Custom Endpoint") can +drive a local model. Streaming uses the native OAI chunk path (`requestChatCompletionStream` / +`receiveChatCompletionChunk`), preserving `delta.tool_calls` for agent mode. Follow-ups, deferred +until requested: + +- **Multi-model registry.** Only one model id is advertised/served today; support several models + chosen by the request `model` field (and listed in `/v1/models`). +- **`stream_options.include_usage` passthrough** so the final streamed `usage` chunk is emitted + (needs a generic raw-param passthrough on `InferenceParameters`, or explicit mapping). +- **Additional `apiType`s.** VS Code "Custom Endpoint" also offers Anthropic `messages` and OpenAI + `responses`; only `chat-completions` is implemented. Also consider `/v1/completions` and + `/v1/embeddings` routes. +- **Gemma 4 tool-calling validation.** Confirm the pinned llama.cpp (`b9682`) includes the Gemma 4 + tool-call parser fixes (landed upstream ~Apr 2026); if not, bump per the upgrade procedure so + streamed/blocking `tool_calls` come through for Gemma 4 GGUFs. + ### llama.cpp upstream feature exposure (queued, deferred by policy) These are JNI plumbing items for upstream API additions. Policy: add only after a real user request — they are mostly relevant to specific model families or specialized workflows. diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 0836ea32..ad499e20 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include // We store some references to Java classes and their fields/methods here to speed up things for later and to fail // early on if anything can't be found. This happens when the JVM loads the shared library (see `JNI_OnLoad`). @@ -111,32 +111,27 @@ jobject o_log_callback = nullptr; // within each table only matters for the human reader. // --------------------------------------------------------------------------- static jclass *const g_global_class_refs[] = { - &c_llama_model, &c_string, &c_hash_map, &c_map, &c_set, - &c_entry, &c_iterator, &c_integer, &c_float, &c_biconsumer, - &c_llama_error, &c_log_level, &c_log_format, &c_error_oom, + &c_llama_model, &c_string, &c_hash_map, &c_map, &c_set, &c_entry, &c_iterator, + &c_integer, &c_float, &c_biconsumer, &c_llama_error, &c_log_level, &c_log_format, &c_error_oom, }; static jobject *const g_global_object_refs[] = { - &o_utf_8, - &o_log_level_debug, &o_log_level_info, &o_log_level_warn, &o_log_level_error, - &o_log_format_json, &o_log_format_text, + &o_utf_8, &o_log_level_debug, &o_log_level_info, &o_log_level_warn, + &o_log_level_error, &o_log_format_json, &o_log_format_text, }; // Maps every object that is fetched from a Java static field on load to the // (class, field) pair it should be looked up from. struct static_object_binding { - jobject *target; - jclass *cls; + jobject *target; + jclass *cls; jfieldID *field; }; static const static_object_binding g_static_object_bindings[] = { - {&o_log_level_debug, &c_log_level, &f_log_level_debug}, - {&o_log_level_info, &c_log_level, &f_log_level_info}, - {&o_log_level_warn, &c_log_level, &f_log_level_warn}, - {&o_log_level_error, &c_log_level, &f_log_level_error}, - {&o_log_format_json, &c_log_format, &f_log_format_json}, - {&o_log_format_text, &c_log_format, &f_log_format_text}, + {&o_log_level_debug, &c_log_level, &f_log_level_debug}, {&o_log_level_info, &c_log_level, &f_log_level_info}, + {&o_log_level_warn, &c_log_level, &f_log_level_warn}, {&o_log_level_error, &c_log_level, &f_log_level_error}, + {&o_log_format_json, &c_log_format, &f_log_format_json}, {&o_log_format_text, &c_log_format, &f_log_format_text}, }; /** @@ -160,11 +155,9 @@ static void throw_invalid_request(JNIEnv *env, const std::exception &e) { * Returns true if result is non-null and not an error. * On failure throws via JNI and returns false. Callers must return immediately. */ -[[nodiscard]] static bool result_ok_or_throw(JNIEnv *env, - const server_task_result_ptr &result) { +[[nodiscard]] static bool result_ok_or_throw(JNIEnv *env, const server_task_result_ptr &result) { if (!result || result->is_error()) { - env->ThrowNew(c_llama_error, - result ? get_result_error_message(result).c_str() : "No result"); + env->ThrowNew(c_llama_error, result ? get_result_error_message(result).c_str() : "No result"); return false; } return true; @@ -174,9 +167,7 @@ static void throw_invalid_request(JNIEnv *env, const std::exception &e) { * Returns true if the batch completed without a task-level error. * On failure throws via JNI and returns false. Callers must return immediately. */ -[[nodiscard]] static bool batch_ok_or_throw( - JNIEnv *env, - const server_response_reader::batch_response &br) { +[[nodiscard]] static bool batch_ok_or_throw(JNIEnv *env, const server_response_reader::batch_response &br) { if (br.error) { env->ThrowNew(c_llama_error, get_result_error_message(br.error).c_str()); return false; @@ -189,10 +180,7 @@ static void throw_invalid_request(JNIEnv *env, const std::exception &e) { * write the result into `out`. Returns true on success; on failure throws and * returns false. */ -[[nodiscard]] static bool parse_oai_chat_params(JNIEnv *env, - server_context *ctx_server, - json &body, - json &out) { +[[nodiscard]] static bool parse_oai_chat_params(JNIEnv *env, server_context *ctx_server, json &body, json &out) { try { std::vector files; auto meta = ctx_server->get_meta(); @@ -206,32 +194,24 @@ static void throw_invalid_request(JNIEnv *env, const std::exception &e) { // Tokenise the prompt in `data` and fill task.tokens + task.params. // Callers must wrap this in try/catch (params_from_json_cmpl can throw). -static void populate_completion_task(server_task &task, - jllama_context *jctx, - int n_ctx_slot, - const std::vector &logit_bias_eog, - const json &data) { - auto tokenized_prompts = tokenize_input_prompts( - jctx->vocab, nullptr, data.at("prompt"), true, true); +static void populate_completion_task(server_task &task, jllama_context *jctx, int n_ctx_slot, + const std::vector &logit_bias_eog, const json &data) { + auto tokenized_prompts = tokenize_input_prompts(jctx->vocab, nullptr, data.at("prompt"), true, true); if (!tokenized_prompts.empty()) { task.tokens = std::move(tokenized_prompts[0]); } - task.params = server_task::params_from_json_cmpl( - jctx->vocab, jctx->params, n_ctx_slot, logit_bias_eog, data); + task.params = server_task::params_from_json_cmpl(jctx->vocab, jctx->params, n_ctx_slot, logit_bias_eog, data); } -[[nodiscard]] static jint dispatch_streaming_completion(JNIEnv *env, - jllama_context *jctx, - const json &data, - server_task_type task_type, - task_response_type res_type) { +[[nodiscard]] static jint dispatch_streaming_completion(JNIEnv *env, jllama_context *jctx, const json &data, + server_task_type task_type, task_response_type res_type) { server_context *ctx_server = &jctx->server; auto meta = ctx_server->get_meta(); - auto *rd = new server_response_reader(ctx_server->get_response_reader()); - int tid = rd->get_new_id(); + auto *rd = new server_response_reader(ctx_server->get_response_reader()); + int tid = rd->get_new_id(); try { server_task task(task_type); - task.id = tid; + task.id = tid; populate_completion_task(task, jctx, meta.slot_n_ctx, meta.logit_bias_eog, data); task.params.res_type = res_type; rd->post_task(std::move(task)); @@ -252,14 +232,11 @@ static void populate_completion_task(server_task &task, * handleInfill — the blocking completion path. * On error: throws via JNI and returns nullptr. */ -[[nodiscard]] static jstring dispatch_blocking_completion(JNIEnv *env, - jllama_context *jctx, - const json &data, - server_task_type task_type, - task_response_type res_type) { +[[nodiscard]] static jstring dispatch_blocking_completion(JNIEnv *env, jllama_context *jctx, const json &data, + server_task_type task_type, task_response_type res_type) { server_context *ctx_server = &jctx->server; auto meta = ctx_server->get_meta(); - auto rd = ctx_server->get_response_reader(); + auto rd = ctx_server->get_response_reader(); server_task task(task_type); task.id = rd.get_new_id(); try { @@ -271,7 +248,8 @@ static void populate_completion_task(server_task &task, task.params.res_type = res_type; rd.post_task(std::move(task)); auto br = rd.wait_for_all([] { return false; }); - if (!batch_ok_or_throw(env, br)) return nullptr; + if (!batch_ok_or_throw(env, br)) + return nullptr; return results_to_jstring_impl(env, br.results); } @@ -297,62 +275,53 @@ std::string parse_jstring(JNIEnv *env, jstring java_string) { * Combines parse_jstring + json::parse, which every parameter-taking JNI * function needs before it can read its arguments. */ -static json parse_json_params(JNIEnv *env, jstring jparams) { - return json::parse(parse_jstring(env, jparams)); -} +static json parse_json_params(JNIEnv *env, jstring jparams) { return json::parse(parse_jstring(env, jparams)); } /** * Convenience wrapper around require_json_field_impl (jni_helpers.hpp). * Returns false and throws if `field` is absent from `data`. */ -[[nodiscard]] static bool require_json_field(JNIEnv *env, const json &data, - const char *field) { +[[nodiscard]] static bool require_json_field(JNIEnv *env, const json &data, const char *field) { return require_json_field_impl(env, data, field, c_llama_error); } // Build a single indexed token task for batch submission (rerank and embedding). // Assigns the reader-allocated id; moves tokens into the task. -[[nodiscard]] static server_task build_indexed_token_task(server_response_reader &rd, - server_task_type type, - server_tokens &&tokens, - int index, - task_response_type res_type) { +[[nodiscard]] static server_task build_indexed_token_task(server_response_reader &rd, server_task_type type, + server_tokens &&tokens, int index, + task_response_type res_type) { server_task task(type); - task.id = rd.get_new_id(); - task.tokens = std::move(tokens); - task.index = index; + task.id = rd.get_new_id(); + task.tokens = std::move(tokens); + task.index = index; task.params.res_type = res_type; return task; } // Post a single pre-built task, wait for its result, and return JSON as a jstring. // The task's id field is assigned here; callers must not set it beforehand. -[[nodiscard]] static jstring dispatch_one_shot_task(JNIEnv *env, - server_context *ctx_server, - server_task task) { - auto rd = ctx_server->get_response_reader(); - task.id = rd.get_new_id(); +[[nodiscard]] static jstring dispatch_one_shot_task(JNIEnv *env, server_context *ctx_server, server_task task) { + auto rd = ctx_server->get_response_reader(); + task.id = rd.get_new_id(); rd.post_task(std::move(task)); auto result = rd.next([] { return false; }); - if (!result_ok_or_throw(env, result)) return nullptr; + if (!result_ok_or_throw(env, result)) + return nullptr; return json_to_jstring_impl(env, result->to_json()); } // Post a single slot file task (SAVE or RESTORE), wait for its result, and // return the result JSON as a jstring. -[[nodiscard]] static jstring exec_slot_file_task(JNIEnv *env, - server_context *ctx_server, - jint slotId, - jstring jfilename, - server_task_type task_type, - const char *empty_filename_error) { +[[nodiscard]] static jstring exec_slot_file_task(JNIEnv *env, server_context *ctx_server, jint slotId, + jstring jfilename, server_task_type task_type, + const char *empty_filename_error) { const std::string filename = jfilename != nullptr ? parse_jstring(env, jfilename) : ""; if (filename.empty()) { env->ThrowNew(c_llama_error, empty_filename_error); return nullptr; } server_task task(task_type); - task.slot_action.id_slot = slotId; + task.slot_action.id_slot = slotId; task.slot_action.filename = filename; task.slot_action.filepath = filename; return dispatch_one_shot_task(env, ctx_server, std::move(task)); @@ -446,9 +415,12 @@ void log_callback_trampoline(ggml_log_level level, const char *text, void *user_ // Validates the jllama_context at every JNI entry point. Declares both // `jctx` and `ctx_server` in the caller's scope; returns the given sentinel // (omit for void functions) if the model is not loaded. -#define REQUIRE_SERVER_CONTEXT(...) \ - auto *jctx = get_jllama_context(env, obj); \ - if (!jctx) { env->ThrowNew(c_llama_error, "Model is not loaded"); return __VA_ARGS__; } \ +#define REQUIRE_SERVER_CONTEXT(...) \ + auto *jctx = get_jllama_context(env, obj); \ + if (!jctx) { \ + env->ThrowNew(c_llama_error, "Model is not loaded"); \ + return __VA_ARGS__; \ + } \ server_context *ctx_server = &jctx->server /** @@ -484,9 +456,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { c_log_format = env->FindClass("net/ladenthin/llama/args/LogFormat"); c_error_oom = env->FindClass("java/lang/OutOfMemoryError"); - if (!(c_llama_model && c_standard_charsets && c_string && c_hash_map && c_map && - c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level && - c_log_format && c_error_oom)) { + if (!(c_llama_model && c_standard_charsets && c_string && c_hash_map && c_map && c_set && c_entry && c_iterator && + c_integer && c_float && c_biconsumer && c_llama_error && c_log_level && c_log_format && c_error_oom)) { goto error; } @@ -534,8 +505,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { f_log_format_json = env->GetStaticFieldID(c_log_format, "JSON", "Lnet/ladenthin/llama/args/LogFormat;"); f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lnet/ladenthin/llama/args/LogFormat;"); - if (!(f_model_pointer && f_utf_8 && f_log_level_debug && f_log_level_info && - f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text)) { + if (!(f_model_pointer && f_utf_8 && f_log_level_debug && f_log_level_info && f_log_level_warn && + f_log_level_error && f_log_format_json && f_log_format_text)) { goto error; } @@ -603,8 +574,8 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { // by the single load_model_impl call. namespace { struct load_progress_ud { - JNIEnv *env; - jobject callback; + JNIEnv *env; + jobject callback; jmethodID on_progress; }; @@ -640,9 +611,9 @@ static void load_model_impl(JNIEnv *env, jobject obj, jobjectArray jparams, jobj common_init(); - auto *jctx = new jllama_context(); - jctx->vocab_only = vocab_only; - jctx->params = params; + auto *jctx = new jllama_context(); + jctx->vocab_only = vocab_only; + jctx->params = params; auto fail_load = [&](const char *msg) { if (jctx->vocab_only_model) { @@ -685,14 +656,14 @@ static void load_model_impl(JNIEnv *env, jobject obj, jobjectArray jparams, jobj load_progress_ud progress_ud{}; if (progress != nullptr) { jclass cb_cls = env->GetObjectClass(progress); - progress_ud.env = env; - progress_ud.callback = progress; + progress_ud.env = env; + progress_ud.callback = progress; progress_ud.on_progress = env->GetMethodID(cb_cls, "onProgress", "(F)Z"); if (progress_ud.on_progress == nullptr) { fail_load("LoadProgressCallback.onProgress(float) not found"); return; } - params.load_progress_callback = jni_load_progress_trampoline; + params.load_progress_callback = jni_load_progress_trampoline; params.load_progress_callback_user_data = &progress_ud; } @@ -709,9 +680,9 @@ static void load_model_impl(JNIEnv *env, jobject obj, jobjectArray jparams, jobj auto meta = jctx->server.get_meta(); LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, common_chat_templates_source(meta.chat_params.tmpls.get()).c_str(), - common_chat_format_example(meta.chat_params.tmpls.get(), - jctx->params.use_jinja, - jctx->params.default_template_kwargs).c_str()); + common_chat_format_example(meta.chat_params.tmpls.get(), jctx->params.use_jinja, + jctx->params.default_template_kwargs) + .c_str()); } jctx->worker = std::thread([jctx]() { @@ -745,8 +716,8 @@ JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_loadModel(JNIEnv *env } JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_loadModelWithProgress(JNIEnv *env, jobject obj, - jobjectArray jparams, - jobject callback) { + jobjectArray jparams, + jobject callback) { load_model_impl(env, obj, jparams, callback); } @@ -755,7 +726,7 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_getModelMetaJson(J if (jctx->vocab_only) { json meta = { {"vocab_type", llama_vocab_type(jctx->vocab)}, - {"n_vocab", llama_vocab_n_tokens(jctx->vocab)}, + {"n_vocab", llama_vocab_n_tokens(jctx->vocab)}, }; return json_to_jstring_impl(env, meta); } @@ -767,27 +738,26 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_getModelMetaJson(J llama_model_meta_val_str(mdl, "general.architecture", arch_buf, sizeof(arch_buf)); } json j = { - {"vocab_type", m.model_vocab_type}, - {"n_vocab", m.model_vocab_n_tokens}, - {"n_ctx_train", m.model_n_ctx_train}, - {"n_embd", m.model_n_embd_inp}, - {"n_params", m.model_n_params}, - {"size", m.model_size}, - {"modalities", {{"vision", m.has_inp_image}, {"audio", m.has_inp_audio}}}, - {"name", m.model_name}, + {"vocab_type", m.model_vocab_type}, + {"n_vocab", m.model_vocab_n_tokens}, + {"n_ctx_train", m.model_n_ctx_train}, + {"n_embd", m.model_n_embd_inp}, + {"n_params", m.model_n_params}, + {"size", m.model_size}, + {"modalities", {{"vision", m.has_inp_image}, {"audio", m.has_inp_audio}}}, + {"name", m.model_name}, {"architecture", std::string(arch_buf)}, }; return json_to_jstring_impl(env, j); } -JNIEXPORT jint JNICALL Java_net_ladenthin_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) { +JNIEXPORT jint JNICALL Java_net_ladenthin_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, + jstring jparams) { REQUIRE_SERVER_CONTEXT(0); json data = parse_json_params(env, jparams); - const server_task_type type = is_infill_request(data) - ? SERVER_TASK_TYPE_INFILL - : SERVER_TASK_TYPE_COMPLETION; + const server_task_type type = is_infill_request(data) ? SERVER_TASK_TYPE_INFILL : SERVER_TASK_TYPE_COMPLETION; return dispatch_streaming_completion(env, jctx, data, type, TASK_RESPONSE_TYPE_NONE); } @@ -798,7 +768,7 @@ JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_releaseTask(JNIEnv *e } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_receiveCompletionJson(JNIEnv *env, jobject obj, - jint id_task) { + jint id_task) { REQUIRE_SERVER_CONTEXT(nullptr); server_response_reader *rd; @@ -839,6 +809,52 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_receiveCompletionJ return json_to_jstring_impl(env, response); } +// Streaming OpenAI chat: poll one step of a chat.completion.chunk stream. +// Returns {"data": , "stop": } — `data` is exactly +// what the streaming result's to_json() produced (a single chunk object for a +// partial token, or a JSON array of chunks for the final delta + usage). The +// uniform envelope avoids injecting a "stop" key into an array. Skips the +// header-only nullptr sentinels (upstream b9437+) and releases the reader on stop. +JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_receiveChatCompletionChunk(JNIEnv *env, jobject obj, + jint id_task) { + REQUIRE_SERVER_CONTEXT(nullptr); + + server_response_reader *rd; + { + std::lock_guard lk(jctx->readers_mutex); + auto it = jctx->readers.find(id_task); + if (it == jctx->readers.end()) { + env->ThrowNew(c_llama_error, "Task not found"); + return nullptr; + } + rd = it->second.get(); + } + + json payload; + bool stop = false; + while (true) { + server_task_result_ptr result = rd->next([] { return false; }); + + if (!result_ok_or_throw(env, result)) { + erase_reader(jctx, id_task); + return nullptr; + } + + json chunk = result->to_json(); + if (chunk.is_null()) { + continue; + } + payload = std::move(chunk); + stop = result->is_stop(); + if (stop) { + erase_reader(jctx, id_task); + } + break; + } + + return json_to_jstring_impl(env, wrap_stream_chunk(std::move(payload), stop)); +} + JNIEXPORT jfloatArray JNICALL Java_net_ladenthin_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) { REQUIRE_SERVER_CONTEXT(nullptr); @@ -850,15 +866,16 @@ JNIEXPORT jfloatArray JNICALL Java_net_ladenthin_llama_LlamaModel_embed(JNIEnv * SRV_INF("Calling embedding '%s'\n", prompt.c_str()); auto tokens = tokenize_mixed(jctx->vocab, prompt, true, true); - auto rd = ctx_server->get_response_reader(); + auto rd = ctx_server->get_response_reader(); server_task task(SERVER_TASK_TYPE_EMBEDDING); - task.id = rd.get_new_id(); + task.id = rd.get_new_id(); task.tokens = server_tokens(tokens, false); - task.index = 0; + task.index = 0; rd.post_task(std::move(task)); auto br = rd.wait_for_all([] { return false; }); - if (!batch_ok_or_throw(env, br)) return nullptr; + if (!batch_ok_or_throw(env, br)) + return nullptr; auto *embd_result = dynamic_cast(br.results[0].get()); if (!embd_result || embd_result->embedding.empty() || embd_result->embedding[0].empty()) { @@ -872,14 +889,15 @@ JNIEXPORT jfloatArray JNICALL Java_net_ladenthin_llama_LlamaModel_embed(JNIEnv * } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleRerank(JNIEnv *env, jobject obj, jstring jprompt, - jobjectArray documents) { + jobjectArray documents) { REQUIRE_SERVER_CONTEXT(nullptr); { auto meta = ctx_server->get_meta(); if (!jctx->params.embedding || meta.pooling_type != LLAMA_POOLING_TYPE_RANK) { - env->ThrowNew(c_llama_error, - "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); + env->ThrowNew( + c_llama_error, + "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); return nullptr; } } @@ -888,7 +906,7 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleRerank(JNIEn const jsize amount_documents = env->GetArrayLength(documents); auto *document_array = parse_string_array(env, documents, amount_documents); - auto document_vector = std::vector(document_array, document_array + amount_documents); + auto document_vector = std::vector(document_array, document_array + amount_documents); free_string_array(document_array, amount_documents); const llama_model *model = llama_get_model(ctx_server->get_llama_context()); @@ -896,14 +914,15 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleRerank(JNIEn std::vector tasks; tasks.reserve(document_vector.size()); for (size_t i = 0; i < document_vector.size(); i++) { - tasks.push_back(build_indexed_token_task(rd, SERVER_TASK_TYPE_RERANK, - format_prompt_rerank(model, jctx->vocab, nullptr, prompt, document_vector[i]), + tasks.push_back(build_indexed_token_task( + rd, SERVER_TASK_TYPE_RERANK, format_prompt_rerank(model, jctx->vocab, nullptr, prompt, document_vector[i]), static_cast(i), TASK_RESPONSE_TYPE_NONE)); } rd.post_tasks(std::move(tasks)); auto br = rd.wait_for_all([] { return false; }); - if (!batch_ok_or_throw(env, br)) return nullptr; + if (!batch_ok_or_throw(env, br)) + return nullptr; return json_to_jstring_impl(env, rerank_results_to_json(br.results, document_vector)); } @@ -913,35 +932,53 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_applyTemplate(JNIE json data = parse_json_params(env, jparams); json templateData; - if (!parse_oai_chat_params(env, ctx_server, data, templateData)) return nullptr; + if (!parse_oai_chat_params(env, ctx_server, data, templateData)) + return nullptr; std::string tok_str = templateData.at("prompt"); return env->NewStringUTF(tok_str.c_str()); } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleChatCompletions(JNIEnv *env, jobject obj, - jstring jparams) { + jstring jparams) { REQUIRE_SERVER_CONTEXT(nullptr); json body = parse_json_params(env, jparams); json data; - if (!parse_oai_chat_params(env, ctx_server, body, data)) return nullptr; + if (!parse_oai_chat_params(env, ctx_server, body, data)) + return nullptr; - return dispatch_blocking_completion(env, jctx, data, - SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_OAI_CHAT); + return dispatch_blocking_completion(env, jctx, data, SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_OAI_CHAT); } JNIEXPORT jint JNICALL Java_net_ladenthin_llama_LlamaModel_requestChatCompletion(JNIEnv *env, jobject obj, - jstring jparams) { + jstring jparams) { REQUIRE_SERVER_CONTEXT(0); json body = parse_json_params(env, jparams); // Chat template already applied by parse_oai_chat_params; no OAI wrapping on the streaming path. json data; - if (!parse_oai_chat_params(env, ctx_server, body, data)) return 0; + if (!parse_oai_chat_params(env, ctx_server, body, data)) + return 0; - return dispatch_streaming_completion(env, jctx, data, - SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_NONE); + return dispatch_streaming_completion(env, jctx, data, SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_NONE); +} + +// Streaming OpenAI chat with OAI-formatted chunks. Mirrors requestChatCompletion +// but sets TASK_RESPONSE_TYPE_OAI_CHAT so each polled result formats as an +// OpenAI chat.completion.chunk (including streamed delta.tool_calls). The params +// must carry "stream": true so the upstream formatter emits chunk deltas; poll +// the returned task id with receiveChatCompletionChunk. +JNIEXPORT jint JNICALL Java_net_ladenthin_llama_LlamaModel_requestChatCompletionStream(JNIEnv *env, jobject obj, + jstring jparams) { + REQUIRE_SERVER_CONTEXT(0); + + json body = parse_json_params(env, jparams); + json data; + if (!parse_oai_chat_params(env, ctx_server, body, data)) + return 0; + + return dispatch_streaming_completion(env, jctx, data, SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_OAI_CHAT); } JNIEXPORT jintArray JNICALL Java_net_ladenthin_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { @@ -961,7 +998,7 @@ static std::string detokenize(jllama_context *jctx, const std::vectorSetLongField(obj, f_model_pointer, 0); @@ -1006,7 +1044,7 @@ JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_cancelCompletion(JNIE } JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_setLogger(JNIEnv *env, jclass clazz, jobject log_format, - jobject jcallback) { + jobject jcallback) { if (o_log_callback != nullptr) { env->DeleteGlobalRef(o_log_callback); } @@ -1031,7 +1069,7 @@ JNIEXPORT void JNICALL Java_net_ladenthin_llama_LlamaModel_setLogger(JNIEnv *env } JNIEXPORT jbyteArray JNICALL Java_net_ladenthin_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *env, jclass clazz, - jstring j_schema) { + jstring j_schema) { const std::string c_schema = parse_jstring(env, j_schema); nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema); const std::string c_grammar = json_schema_to_grammar(c_schema_json); @@ -1039,16 +1077,15 @@ JNIEXPORT jbyteArray JNICALL Java_net_ladenthin_llama_LlamaModel_jsonSchemaToGra } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleCompletions(JNIEnv *env, jobject obj, - jstring jparams) { + jstring jparams) { REQUIRE_SERVER_CONTEXT(nullptr); json data = parse_json_params(env, jparams); - return dispatch_blocking_completion(env, jctx, data, - SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_NONE); + return dispatch_blocking_completion(env, jctx, data, SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_NONE); } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleCompletionsOai(JNIEnv *env, jobject obj, - jstring jparams) { + jstring jparams) { REQUIRE_SERVER_CONTEXT(nullptr); json body = parse_json_params(env, jparams); @@ -1060,8 +1097,7 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleCompletionsO return nullptr; } - return dispatch_blocking_completion(env, jctx, data, - SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_OAI_CMPL); + return dispatch_blocking_completion(env, jctx, data, SERVER_TASK_TYPE_COMPLETION, TASK_RESPONSE_TYPE_OAI_CMPL); } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleInfill(JNIEnv *env, jobject obj, jstring jparams) { @@ -1070,8 +1106,7 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleInfill(JNIEn // Check FIM token support via server_context_meta (populated from the // same llama_vocab_fim_* calls inside server-context). auto meta = ctx_server->get_meta(); - if (meta.fim_pre_token == LLAMA_TOKEN_NULL || - meta.fim_sub_token == LLAMA_TOKEN_NULL || + if (meta.fim_pre_token == LLAMA_TOKEN_NULL || meta.fim_sub_token == LLAMA_TOKEN_NULL || meta.fim_mid_token == LLAMA_TOKEN_NULL) { env->ThrowNew(c_llama_error, "Model does not support fill-in-the-middle infill"); return nullptr; @@ -1079,30 +1114,27 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleInfill(JNIEn json data = parse_json_params(env, jparams); - if (!require_json_field(env, data, "input_prefix")) return nullptr; - if (!require_json_field(env, data, "input_suffix")) return nullptr; + if (!require_json_field(env, data, "input_prefix")) + return nullptr; + if (!require_json_field(env, data, "input_suffix")) + return nullptr; json input_extra = json_value(data, "input_extra", json::array()); data["input_extra"] = input_extra; std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = - tokenize_input_prompts(jctx->vocab, nullptr, prompt, false, true); - - data["prompt"] = format_prompt_infill(jctx->vocab, - data.at("input_prefix"), data.at("input_suffix"), - data.at("input_extra"), - jctx->params.n_batch, jctx->params.n_predict, - meta.slot_n_ctx, jctx->params.spm_infill, - tokenized_prompts.empty() ? llama_tokens() - : tokenized_prompts[0].get_tokens()); - - return dispatch_blocking_completion(env, jctx, data, - SERVER_TASK_TYPE_INFILL, TASK_RESPONSE_TYPE_NONE); + std::vector tokenized_prompts = tokenize_input_prompts(jctx->vocab, nullptr, prompt, false, true); + + data["prompt"] = + format_prompt_infill(jctx->vocab, data.at("input_prefix"), data.at("input_suffix"), data.at("input_extra"), + jctx->params.n_batch, jctx->params.n_predict, meta.slot_n_ctx, jctx->params.spm_infill, + tokenized_prompts.empty() ? llama_tokens() : tokenized_prompts[0].get_tokens()); + + return dispatch_blocking_completion(env, jctx, data, SERVER_TASK_TYPE_INFILL, TASK_RESPONSE_TYPE_NONE); } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleEmbeddings(JNIEnv *env, jobject obj, - jstring jparams, jboolean joaiCompat) { + jstring jparams, jboolean joaiCompat) { REQUIRE_SERVER_CONTEXT(nullptr); if (!require_embedding_support(env, jctx->params.embedding, c_llama_error)) { @@ -1126,16 +1158,16 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleEmbeddings(J json prompt; bool use_base64 = false; try { - prompt = extract_embedding_prompt(body, force_no_oaicompat); + prompt = extract_embedding_prompt(body, force_no_oaicompat); use_base64 = parse_encoding_format(body); } catch (const std::exception &e) { env->ThrowNew(c_llama_error, e.what()); return nullptr; } - if (force_no_oaicompat) res_type = TASK_RESPONSE_TYPE_NONE; + if (force_no_oaicompat) + res_type = TASK_RESPONSE_TYPE_NONE; - std::vector tokenized_prompts = - tokenize_input_prompts(jctx->vocab, nullptr, prompt, true, true); + std::vector tokenized_prompts = tokenize_input_prompts(jctx->vocab, nullptr, prompt, true, true); for (const auto &toks : tokenized_prompts) { if (toks.get_tokens().empty()) { @@ -1149,32 +1181,34 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleEmbeddings(J tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { tasks.push_back(build_indexed_token_task(rd, SERVER_TASK_TYPE_EMBEDDING, - server_tokens(tokenized_prompts[i].get_tokens(), false), - static_cast(i), res_type)); + server_tokens(tokenized_prompts[i].get_tokens(), false), + static_cast(i), res_type)); } rd.post_tasks(std::move(tasks)); auto br = rd.wait_for_all([] { return false; }); - if (!batch_ok_or_throw(env, br)) return nullptr; + if (!batch_ok_or_throw(env, br)) + return nullptr; json responses = json::array(); for (const auto &result : br.results) { responses.push_back(result->to_json()); } json out = (res_type == TASK_RESPONSE_TYPE_OAI_EMBD) - ? format_embeddings_response_oaicompat(body, json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), responses, use_base64) - : responses; + ? format_embeddings_response_oaicompat( + body, json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), responses, use_base64) + : responses; return json_to_jstring_impl(env, out); } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleTokenize(JNIEnv *env, jobject obj, jstring jcontent, - jboolean jaddSpecial, - jboolean jwithPieces) { + jboolean jaddSpecial, + jboolean jwithPieces) { REQUIRE_SERVER_CONTEXT(nullptr); - const std::string content = parse_jstring(env, jcontent); - const bool add_special = jaddSpecial; - const bool with_pieces = jwithPieces; + const std::string content = parse_jstring(env, jcontent); + const bool add_special = jaddSpecial; + const bool with_pieces = jwithPieces; llama_tokens tokens = tokenize_mixed(jctx->vocab, content, add_special, true); @@ -1201,7 +1235,7 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleTokenize(JNI } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleDetokenize(JNIEnv *env, jobject obj, - jintArray jtokens) { + jintArray jtokens) { REQUIRE_SERVER_CONTEXT(nullptr); const auto tokens = jint_array_to_tokens_impl(env, jtokens); @@ -1209,20 +1243,18 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleDetokenize(J } JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleSlotAction(JNIEnv *env, jobject obj, jint action, - jint slotId, jstring jfilename) { + jint slotId, jstring jfilename) { REQUIRE_SERVER_CONTEXT(nullptr); switch (action) { case 0: // LIST — get slot info via metrics task return dispatch_one_shot_task(env, ctx_server, server_task(SERVER_TASK_TYPE_METRICS)); case 1: // SAVE - return exec_slot_file_task(env, ctx_server, slotId, jfilename, - SERVER_TASK_TYPE_SLOT_SAVE, - "Filename is required for slot save"); + return exec_slot_file_task(env, ctx_server, slotId, jfilename, SERVER_TASK_TYPE_SLOT_SAVE, + "Filename is required for slot save"); case 2: // RESTORE - return exec_slot_file_task(env, ctx_server, slotId, jfilename, - SERVER_TASK_TYPE_SLOT_RESTORE, - "Filename is required for slot restore"); + return exec_slot_file_task(env, ctx_server, slotId, jfilename, SERVER_TASK_TYPE_SLOT_RESTORE, + "Filename is required for slot restore"); case 3: { // ERASE server_task task(SERVER_TASK_TYPE_SLOT_ERASE); task.slot_action.id_slot = slotId; @@ -1235,18 +1267,18 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleSlotAction(J } JNIEXPORT jboolean JNICALL Java_net_ladenthin_llama_LlamaModel_configureParallelInference(JNIEnv *env, jobject obj, - jstring jconfig) { + jstring jconfig) { REQUIRE_SERVER_CONTEXT(JNI_FALSE); (void)obj; json config = parse_json_params(env, jconfig); std::optional slot_sim_opt; - std::optional n_threads_opt; - std::optional n_threads_batch_opt; + std::optional n_threads_opt; + std::optional n_threads_batch_opt; try { - slot_sim_opt = parse_slot_prompt_similarity(config); - n_threads_opt = parse_positive_int_config(config, "n_threads"); + slot_sim_opt = parse_slot_prompt_similarity(config); + n_threads_opt = parse_positive_int_config(config, "n_threads"); n_threads_batch_opt = parse_positive_int_config(config, "n_threads_batch"); } catch (const std::invalid_argument &e) { env->ThrowNew(c_llama_error, e.what()); @@ -1260,17 +1292,16 @@ JNIEXPORT jboolean JNICALL Java_net_ladenthin_llama_LlamaModel_configureParallel if (n_threads_opt.has_value() || n_threads_batch_opt.has_value()) { llama_context *lctx = ctx_server->get_llama_context(); if (lctx == nullptr) { - env->ThrowNew(c_llama_error, - "configureParallelInference: llama_context not available " - "(model sleeping or not loaded)"); + env->ThrowNew(c_llama_error, "configureParallelInference: llama_context not available " + "(model sleeping or not loaded)"); return JNI_FALSE; } - const int n = n_threads_opt.value_or(jctx->params.cpuparams.n_threads); + const int n = n_threads_opt.value_or(jctx->params.cpuparams.n_threads); const int nb = n_threads_batch_opt.value_or(jctx->params.cpuparams_batch.n_threads); llama_set_n_threads(lctx, n, nb); // Keep the cached params in sync so a follow-up call that supplies only // the other field reads back the value just applied, not the original. - jctx->params.cpuparams.n_threads = n; + jctx->params.cpuparams.n_threads = n; jctx->params.cpuparams_batch.n_threads = nb; } diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index 7867d59b..ff93bd94 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -126,6 +126,10 @@ JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleChatCompleti */ JNIEXPORT jint JNICALL Java_net_ladenthin_llama_LlamaModel_requestChatCompletion(JNIEnv *, jobject, jstring); +JNIEXPORT jint JNICALL Java_net_ladenthin_llama_LlamaModel_requestChatCompletionStream(JNIEnv *, jobject, jstring); + +JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_receiveChatCompletionChunk(JNIEnv *, jobject, jint); + JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleCompletions(JNIEnv *, jobject, jstring); JNIEXPORT jstring JNICALL Java_net_ladenthin_llama_LlamaModel_handleCompletionsOai(JNIEnv *, jobject, jstring); diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index 3d9b4605..7de6e154 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -50,23 +50,23 @@ struct server_response_reader; // worker thread. Stored as the Java-side `ctx` (jlong) pointer. // --------------------------------------------------------------------------- struct jllama_context { - server_context server; // value member (pimpl inside) - std::thread worker; - bool vocab_only = false; + server_context server; // value member (pimpl inside) + std::thread worker; + bool vocab_only = false; std::atomic worker_ready{false}; // Cached after load_model() — valid for the lifetime of this context. - const llama_vocab *vocab = nullptr; + const llama_vocab *vocab = nullptr; // Non-null only in vocab-only mode (bypasses server_context entirely). - llama_model *vocab_only_model = nullptr; + llama_model *vocab_only_model = nullptr; // Saved copy of common_params used to load the model. // Required by server_task::params_from_json_cmpl which takes common_params&. - common_params params; + common_params params; // Per-streaming-task response readers, keyed by task id. // Guarded by readers_mutex. - std::mutex readers_mutex; + std::mutex readers_mutex; std::map> readers; }; @@ -80,9 +80,7 @@ inline void erase_reader(jllama_context *jctx, int id_task) { // Guard: throw and return false if the model was loaded without embedding // support enabled. Used by every JNI entry point that produces embeddings. -[[nodiscard]] inline bool require_embedding_support(JNIEnv *env, - bool embedding_enabled, - jclass error_class) { +[[nodiscard]] inline bool require_embedding_support(JNIEnv *env, bool embedding_enabled, jclass error_class) { if (embedding_enabled) { return true; } @@ -101,9 +99,7 @@ inline void erase_reader(jllama_context *jctx, int id_task) { // already deleted (or never fully initialised), which is a valid no-op for // a destructor-style call. // --------------------------------------------------------------------------- -[[nodiscard]] inline jllama_context *get_jllama_context_impl(JNIEnv *env, - jobject obj, - jfieldID field_id) { +[[nodiscard]] inline jllama_context *get_jllama_context_impl(JNIEnv *env, jobject obj, jfieldID field_id) { const jlong handle = env->GetLongField(obj, field_id); if (handle == 0) { return nullptr; @@ -117,10 +113,8 @@ inline void erase_reader(jllama_context *jctx, int id_task) { // Checks that `data` contains the given key. Returns true if present. // On missing key: throws " is required" via JNI and returns false. // --------------------------------------------------------------------------- -[[nodiscard]] inline bool require_json_field_impl(JNIEnv *env, - const nlohmann::json &data, - const char *field, - jclass error_class) { +[[nodiscard]] inline bool require_json_field_impl(JNIEnv *env, const nlohmann::json &data, const char *field, + jclass error_class) { if (data.contains(field)) { return true; } @@ -135,10 +129,9 @@ inline void erase_reader(jllama_context *jctx, int id_task) { // Reads a Java int array into a std::vector and releases the JNI // array elements with JNI_ABORT (read-only — no writeback needed). // --------------------------------------------------------------------------- -[[nodiscard]] inline std::vector jint_array_to_tokens_impl( - JNIEnv *env, jintArray array) { +[[nodiscard]] inline std::vector jint_array_to_tokens_impl(JNIEnv *env, jintArray array) { const jsize length = env->GetArrayLength(array); - jint *elements = env->GetIntArrayElements(array, nullptr); + jint *elements = env->GetIntArrayElements(array, nullptr); std::vector tokens(elements, elements + length); env->ReleaseIntArrayElements(array, elements, JNI_ABORT); return tokens; @@ -170,9 +163,7 @@ inline void erase_reader(jllama_context *jctx, int id_task) { // construction to results_to_json (json_helpers.hpp) and serialisation to // json_to_jstring_impl. // --------------------------------------------------------------------------- -[[nodiscard]] inline jstring results_to_jstring_impl( - JNIEnv *env, - const std::vector &results) { +[[nodiscard]] inline jstring results_to_jstring_impl(JNIEnv *env, const std::vector &results) { return json_to_jstring_impl(env, results_to_json(results)); } @@ -184,13 +175,9 @@ inline void erase_reader(jllama_context *jctx, int id_task) { // On allocation failure: throws via JNI with oom_class and returns nullptr. // --------------------------------------------------------------------------- template -[[nodiscard]] inline JArray vec_to_jarray_impl( - JNIEnv *env, - const std::vector &values, - jclass oom_class, - const char *oom_msg, - JArray (JNIEnv_::*alloc)(jsize), - void (JNIEnv_::*copy)(JArray, jsize, jsize, const JElem *)) { +[[nodiscard]] inline JArray vec_to_jarray_impl(JNIEnv *env, const std::vector &values, jclass oom_class, + const char *oom_msg, JArray (JNIEnv_::*alloc)(jsize), + void (JNIEnv_::*copy)(JArray, jsize, jsize, const JElem *)) { const jsize len = static_cast(values.size()); JArray arr = (env->*alloc)(len); if (arr == nullptr) { @@ -202,21 +189,15 @@ template } // Converts a float vector to a Java jfloatArray. -[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl( - JNIEnv *env, - const std::vector &values, - jclass oom_class) { - return vec_to_jarray_impl( - env, values, oom_class, "could not allocate embedding", - &JNIEnv_::NewFloatArray, &JNIEnv_::SetFloatArrayRegion); +[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl(JNIEnv *env, const std::vector &values, + jclass oom_class) { + return vec_to_jarray_impl(env, values, oom_class, "could not allocate embedding", + &JNIEnv_::NewFloatArray, &JNIEnv_::SetFloatArrayRegion); } // Converts a token vector to a Java jintArray. -[[nodiscard]] inline jintArray tokens_to_jint_array_impl( - JNIEnv *env, - const std::vector &tokens, - jclass oom_class) { - return vec_to_jarray_impl( - env, tokens, oom_class, "could not allocate token memory", - &JNIEnv_::NewIntArray, &JNIEnv_::SetIntArrayRegion); +[[nodiscard]] inline jintArray tokens_to_jint_array_impl(JNIEnv *env, const std::vector &tokens, + jclass oom_class) { + return vec_to_jarray_impl(env, tokens, oom_class, "could not allocate token memory", + &JNIEnv_::NewIntArray, &JNIEnv_::SetIntArrayRegion); } diff --git a/src/main/cpp/json_helpers.hpp b/src/main/cpp/json_helpers.hpp index 63788dd1..6d460710 100644 --- a/src/main/cpp/json_helpers.hpp +++ b/src/main/cpp/json_helpers.hpp @@ -32,6 +32,7 @@ // 6. is_infill_request — used by nothing above it // 7. parse_slot_prompt_similarity — used by nothing above it // 8. parse_positive_int_config — used by nothing above it +// 9. wrap_stream_chunk — used by nothing above it #include "nlohmann/json.hpp" @@ -50,8 +51,7 @@ // jni_helpers.hpp, and directly in receiveCompletionJson, embed, and // handleRerank in jllama.cpp. // --------------------------------------------------------------------------- -[[nodiscard]] inline std::string get_result_error_message( - const server_task_result_ptr &result) { +[[nodiscard]] inline std::string get_result_error_message(const server_task_result_ptr &result) { return result->to_json()["message"].get(); } @@ -67,8 +67,7 @@ // This mirrors the OpenAI API convention used by handleCompletions, // handleCompletionsOai, handleChatCompletions, and handleInfill. // --------------------------------------------------------------------------- -[[nodiscard]] inline json results_to_json( - const std::vector &results) { +[[nodiscard]] inline json results_to_json(const std::vector &results) { if (results.size() == 1) { return results[0]->to_json(); } @@ -86,19 +85,14 @@ // Each element contains the original document text (looked up via the // result's "index" field), the index, and the relevance score. // --------------------------------------------------------------------------- -[[nodiscard]] inline json rerank_results_to_json( - const std::vector &results, - const std::vector &documents) { +[[nodiscard]] inline json rerank_results_to_json(const std::vector &results, + const std::vector &documents) { json arr = json::array(); for (const auto &result : results) { const auto out = result->to_json(); - int index = out["index"].get(); + int index = out["index"].get(); float score = out["score"].get(); - arr.push_back({ - {"document", documents[index]}, - {"index", index}, - {"score", score} - }); + arr.push_back({{"document", documents[index]}, {"index", index}, {"score", score}}); } return arr; } @@ -118,8 +112,12 @@ return false; } const std::string format = body.at("encoding_format").get(); - if (format == "base64") { return true; } - if (format == "float") { return false; } + if (format == "base64") { + return true; + } + if (format == "float") { + return false; + } throw std::invalid_argument("encoding_format must be \"float\" or \"base64\""); } @@ -134,8 +132,7 @@ // when "content" was used — the caller must downgrade oaicompat to NONE. // Throws std::invalid_argument if neither "input" nor "content" is present. // --------------------------------------------------------------------------- -[[nodiscard]] inline json extract_embedding_prompt(const json &body, - bool &force_no_oaicompat) { +[[nodiscard]] inline json extract_embedding_prompt(const json &body, bool &force_no_oaicompat) { force_no_oaicompat = false; if (body.count("input") != 0) { return body.at("input"); @@ -167,8 +164,7 @@ // Returns float — validated value in [0.0, 1.0]. // Throws std::invalid_argument — present but outside [0.0, 1.0]. // --------------------------------------------------------------------------- -[[nodiscard]] inline std::optional -parse_slot_prompt_similarity(const json &config) { +[[nodiscard]] inline std::optional parse_slot_prompt_similarity(const json &config) { if (!config.contains("slot_prompt_similarity")) { return std::nullopt; } @@ -188,8 +184,7 @@ parse_slot_prompt_similarity(const json &config) { // Returns int — validated value > 0. // Throws std::invalid_argument(" must be greater than 0") — present but ≤ 0. // --------------------------------------------------------------------------- -[[nodiscard]] inline std::optional -parse_positive_int_config(const json &config, const char *key) { +[[nodiscard]] inline std::optional parse_positive_int_config(const json &config, const char *key) { if (!config.contains(key)) { return std::nullopt; } @@ -199,3 +194,24 @@ parse_positive_int_config(const json &config, const char *key) { } return v; } + +// --------------------------------------------------------------------------- +// wrap_stream_chunk +// +// Wraps one streaming chat result payload together with its stop flag into a +// single transport object so the Java side has a uniform shape to parse: +// +// {"data": , "stop": } +// +// `payload` is whatever a streaming OAI chat result's to_json() produced — a +// single chat.completion.chunk object for a partial token, or a JSON array of +// chunk objects for the final result (final delta + optional usage chunk). +// The Java consumer reads "stop" and emits each element of "data" as its own +// SSE `data:` event. Used by receiveChatCompletionChunk in jllama.cpp. +// --------------------------------------------------------------------------- +[[nodiscard]] inline json wrap_stream_chunk(json payload, bool stop) { + json out; + out["data"] = std::move(payload); + out["stop"] = stop; + return out; +} diff --git a/src/main/cpp/log_helpers.hpp b/src/main/cpp/log_helpers.hpp index 8e38fd85..7c37ab9e 100644 --- a/src/main/cpp/log_helpers.hpp +++ b/src/main/cpp/log_helpers.hpp @@ -20,21 +20,24 @@ // fall-through to mirror llama.cpp's own log routing. [[nodiscard]] inline const char *log_level_name(ggml_log_level level) { switch (level) { - case GGML_LOG_LEVEL_ERROR: return "ERROR"; - case GGML_LOG_LEVEL_WARN: return "WARN"; - case GGML_LOG_LEVEL_DEBUG: return "DEBUG"; + case GGML_LOG_LEVEL_ERROR: + return "ERROR"; + case GGML_LOG_LEVEL_WARN: + return "WARN"; + case GGML_LOG_LEVEL_DEBUG: + return "DEBUG"; case GGML_LOG_LEVEL_INFO: - default: return "INFO"; + default: + return "INFO"; } } // Pure variant taking an explicit timestamp so tests are deterministic. -[[nodiscard]] inline std::string format_log_as_json( - ggml_log_level level, const char *text, std::time_t timestamp) { +[[nodiscard]] inline std::string format_log_as_json(ggml_log_level level, const char *text, std::time_t timestamp) { nlohmann::json log_obj = { {"timestamp", timestamp}, - {"level", log_level_name(level)}, - {"message", text ? text : ""}, + {"level", log_level_name(level)}, + {"message", text ? text : ""}, }; return log_obj.dump(); } diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java index af93f9af..b73f4a0b 100644 --- a/src/main/java/module-info.java +++ b/src/main/java/module-info.java @@ -46,12 +46,18 @@ // the @lombok.Generated annotation carried on generated members has CLASS retention. requires static lombok; + // The OpenAI-compatible endpoint (net.ladenthin.llama.server) uses the JDK's built-in + // com.sun.net.httpserver, so module-path consumers need to read jdk.httpserver. It is a + // platform module (always present in the JDK), not an external dependency. + requires jdk.httpserver; + exports net.ladenthin.llama; exports net.ladenthin.llama.args; exports net.ladenthin.llama.callback; exports net.ladenthin.llama.exception; exports net.ladenthin.llama.json; exports net.ladenthin.llama.parameters; + exports net.ladenthin.llama.server; exports net.ladenthin.llama.value; // net.ladenthin.llama.loader is intentionally NOT exported: native-library loading, // OS detection and process/system-property infrastructure are internal to the module. diff --git a/src/main/java/net/ladenthin/llama/LlamaModel.java b/src/main/java/net/ladenthin/llama/LlamaModel.java index 3e958b03..a298ca8e 100644 --- a/src/main/java/net/ladenthin/llama/LlamaModel.java +++ b/src/main/java/net/ladenthin/llama/LlamaModel.java @@ -13,6 +13,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.function.Consumer; import lombok.ToString; import net.ladenthin.llama.args.LogFormat; import net.ladenthin.llama.callback.CancellationToken; @@ -20,6 +21,7 @@ import net.ladenthin.llama.callback.ToolHandler; import net.ladenthin.llama.exception.LlamaException; import net.ladenthin.llama.json.ChatResponseParser; +import net.ladenthin.llama.json.ChatStreamChunkParser; import net.ladenthin.llama.json.CompletionResponseParser; import net.ladenthin.llama.json.RerankResponseParser; import net.ladenthin.llama.loader.LlamaLoader; @@ -74,6 +76,7 @@ public class LlamaModel implements AutoCloseable { private final CompletionResponseParser completionParser = new CompletionResponseParser(); private final ChatResponseParser chatParser = new ChatResponseParser(); private final RerankResponseParser rerankParser = new RerankResponseParser(); + private final ChatStreamChunkParser chatStreamParser = new ChatStreamChunkParser(); /** * Load with the given {@link net.ladenthin.llama.parameters.ModelParameters}. Make sure to either set @@ -636,6 +639,47 @@ public LlamaIterable generateChat(InferenceParameters parameters) { return new LlamaIterable(new LlamaIterator(this, parameters, true)); } + /** + * Stream an OpenAI-compatible chat completion as {@code chat.completion.chunk} JSON objects, + * feeding each chunk's JSON string to {@code chunkSink} as it is produced. + *

+ * Unlike {@link #generateChat(InferenceParameters)} (which yields raw token text), this method + * routes through the native OpenAI streaming formatter, so each emitted chunk is a ready-to-send + * OpenAI streaming event — including streamed {@code delta.tool_calls} when the model issues a + * tool call. The final chunk carries a non-null {@code finish_reason} and, when the request set + * {@code stream_options.include_usage}, a trailing usage chunk. This is the building block for an + * OpenAI-compatible HTTP endpoint (Server-Sent Events): forward each chunk verbatim as one + * {@code data:} line and emit {@code data: [DONE]} after this method returns. + *

+ * The {@code "messages"} array (and any {@code "tools"}/{@code "tool_choice"}) is forwarded + * verbatim to the native chat-template parser. Streaming is forced on regardless of the + * {@code stream} flag in {@code parameters}. If {@code chunkSink} throws, the in-flight native + * task is cancelled and the exception propagates to the caller. + * + * @param parameters the inference parameters including messages (and optional tools) + * @param chunkSink receiver for each {@code chat.completion.chunk} JSON string, in order + * @throws net.ladenthin.llama.exception.LlamaException if inference fails + */ + public void streamChatCompletion(InferenceParameters parameters, Consumer chunkSink) { + InferenceParameters streaming = parameters.withStream(true); + int taskId = requestChatCompletionStream(streaming.toString()); + boolean stopped = false; + try { + while (!stopped) { + String envelope = receiveChatCompletionChunk(taskId); + stopped = chatStreamParser.feed(envelope, chunkSink); + } + } finally { + // On a clean stop the native reader was already released when the final chunk was + // delivered; this best-effort cancel covers an early exit (e.g. chunkSink threw) so the + // native task/slot is not leaked. Safe here because we are not concurrently inside + // receiveChatCompletionChunk, and cancelling an already-finished task is a no-op. + if (!stopped) { + cancelCompletion(taskId); + } + } + } + /** * Run a blocking completion and return the full result as a JSON string. * This is the JSON-in/JSON-out equivalent of {@link #complete(InferenceParameters)}. @@ -851,4 +895,8 @@ public String restoreSlot(int slotId, String filepath) { public native String handleChatCompletions(String params); native int requestChatCompletion(String params); + + native int requestChatCompletionStream(String params); + + native String receiveChatCompletionChunk(int taskId); } diff --git a/src/main/java/net/ladenthin/llama/json/ChatStreamChunkParser.java b/src/main/java/net/ladenthin/llama/json/ChatStreamChunkParser.java new file mode 100644 index 00000000..a7f12c98 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/json/ChatStreamChunkParser.java @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// SPDX-FileCopyrightText: 2023-2025 Konstantin Herud +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.function.Consumer; + +/** + * Pure JSON transform for the streaming-chat envelope produced by the native + * {@code receiveChatCompletionChunk} method. + * + *

The native side wraps each polled streaming result in a uniform envelope so the + * Java side never has to distinguish a single chunk object from the final chunk array: + *

{@code
+ * { "data": , "stop":  }
+ * }
+ * + *

{@code data} is exactly what an OpenAI streaming chat result's {@code to_json()} + * produced: a single {@code chat.completion.chunk} object for a partial token, or a JSON + * array of chunk objects for the final step (final delta chunk plus an optional usage + * chunk). This parser emits each chunk as its own JSON string so a caller can forward it + * verbatim as one SSE {@code data:} event. + * + *

Stateless and free of JNI / native / model dependencies — testable with JSON string + * literals alone (see {@code ChatStreamChunkParserTest}). This is the Java analogue of + * {@code wrap_stream_chunk} in {@code json_helpers.hpp}. + */ +public class ChatStreamChunkParser { + + /** Creates a new {@link ChatStreamChunkParser}. */ + public ChatStreamChunkParser() {} + + /** Shared Jackson mapper; thread-safe and reused across all instances. */ + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * Parse one streaming envelope and feed each contained {@code chat.completion.chunk} + * JSON string to {@code chunkSink}, in order. + * + *

A {@code data} array yields one {@code chunkSink} call per element; a {@code data} + * object yields a single call; any other shape (absent/null) yields no calls. An + * unparseable envelope is treated as end-of-stream (returns {@code true}) so a polling + * loop cannot spin forever on malformed input. + * + * @param envelopeJson the raw {@code {"data":…,"stop":…}} string from the native layer + * @param chunkSink receiver for each chunk's JSON string (one OpenAI SSE event each) + * @return {@code true} if this envelope marks the end of the stream, else {@code false} + */ + public boolean feed(String envelopeJson, Consumer chunkSink) { + final JsonNode root; + try { + root = OBJECT_MAPPER.readTree(envelopeJson); + } catch (IOException e) { + return true; + } + JsonNode data = root.path("data"); + if (data.isArray()) { + for (JsonNode element : data) { + chunkSink.accept(element.toString()); + } + } else if (data.isObject()) { + chunkSink.accept(data.toString()); + } + return root.path("stop").asBoolean(false); + } +} diff --git a/src/main/java/net/ladenthin/llama/server/ChatBackend.java b/src/main/java/net/ladenthin/llama/server/ChatBackend.java new file mode 100644 index 00000000..cd16a233 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/ChatBackend.java @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; + +/** + * The chat engine seam behind {@link OpenAiCompatServer}. + * + *

Decoupling the HTTP layer from {@link net.ladenthin.llama.LlamaModel} lets the whole server — + * routing, authentication, Server-Sent-Events framing, heartbeats — be exercised by tests with a fake + * backend, with no native library and no model loaded. The production implementation is + * {@link LlamaModelChatBackend}. + * + *

Both methods receive the parsed OpenAI request object (already validated as JSON by the handler). + */ +interface ChatBackend { + + /** + * Run a non-streaming chat completion. + * + * @param request the parsed OpenAI {@code /v1/chat/completions} request + * @return the complete OpenAI {@code chat.completion} response serialized as JSON + * @throws IOException if generation fails in a way the caller should surface as a server error + */ + String complete(JsonNode request) throws IOException; + + /** + * Run a streaming chat completion, delivering each {@code chat.completion.chunk} to {@code sink} + * in order. Implementations must not emit the terminating {@code [DONE]} marker; the caller adds it. + * + * @param request the parsed OpenAI {@code /v1/chat/completions} request + * @param sink receiver for each streamed chunk's JSON + * @throws IOException if a chunk cannot be delivered or generation fails + */ + void stream(JsonNode request, ChunkSink sink) throws IOException; +} diff --git a/src/main/java/net/ladenthin/llama/server/ChunkSink.java b/src/main/java/net/ladenthin/llama/server/ChunkSink.java new file mode 100644 index 00000000..375bea80 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/ChunkSink.java @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import java.io.IOException; + +/** + * Receiver for the individual {@code chat.completion.chunk} JSON strings produced while a streaming + * chat completion runs. + * + *

Distinct from {@link java.util.function.Consumer} because writing a chunk to an HTTP response can + * fail with {@link IOException} (for example when the client disconnects); a checked exception lets that + * failure propagate so the in-flight generation can be cancelled. + */ +@FunctionalInterface +interface ChunkSink { + + /** + * Accept one streaming chunk's JSON text. + * + * @param chunkJson a single {@code chat.completion.chunk} object serialized as JSON + * @throws IOException if the chunk cannot be delivered (e.g. the client closed the connection) + */ + void accept(String chunkJson) throws IOException; +} diff --git a/src/main/java/net/ladenthin/llama/server/LlamaModelChatBackend.java b/src/main/java/net/ladenthin/llama/server/LlamaModelChatBackend.java new file mode 100644 index 00000000..3d418a8b --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/LlamaModelChatBackend.java @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import net.ladenthin.llama.LlamaModel; +import net.ladenthin.llama.parameters.InferenceParameters; + +/** + * Production {@link ChatBackend} that runs requests against a loaded {@link LlamaModel}. + * + *

Non-streaming requests reuse {@link LlamaModel#chatComplete(InferenceParameters)}, whose return + * value is already a verbatim OpenAI {@code chat.completion} body. Streaming requests use + * {@link LlamaModel#streamChatCompletion(InferenceParameters, java.util.function.Consumer)}, which + * emits OpenAI {@code chat.completion.chunk} objects (including {@code delta.tool_calls}). + * + *

The streaming sink may fail with {@link IOException} (client disconnect); because the underlying + * model API takes a {@link java.util.function.Consumer} (no checked exceptions), that failure is + * relayed across the boundary via {@link java.io.UncheckedIOException} and unwrapped here so the + * in-flight native task is cancelled. + */ +final class LlamaModelChatBackend implements ChatBackend { + + private final LlamaModel model; + private final OpenAiRequestMapper mapper; + + /** + * Create a backend over the given model. + * + * @param model the loaded model to run completions against + * @param mapper the OpenAI-request to {@link InferenceParameters} mapper + */ + LlamaModelChatBackend(LlamaModel model, OpenAiRequestMapper mapper) { + this.model = model; + this.mapper = mapper; + } + + @Override + public String complete(JsonNode request) { + return model.chatComplete(mapper.toInferenceParameters(request)); + } + + @Override + public void stream(JsonNode request, ChunkSink sink) throws IOException { + InferenceParameters params = mapper.toInferenceParameters(request); + // Holds an IOException thrown by the sink so it can be rethrown after the model API (which + // only understands unchecked exceptions) unwinds and cancels the native task. + final IOException[] sinkFailure = new IOException[1]; + try { + model.streamChatCompletion(params, chunkJson -> { + try { + sink.accept(chunkJson); + } catch (IOException e) { + sinkFailure[0] = e; + throw new java.io.UncheckedIOException(e); + } + }); + } catch (java.io.UncheckedIOException e) { + IOException cause = sinkFailure[0]; + if (cause != null) { + throw cause; + } + throw e; + } + } +} diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java b/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java new file mode 100644 index 00000000..ac5d6178 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java @@ -0,0 +1,457 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpServer; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import net.ladenthin.llama.LlamaModel; +import net.ladenthin.llama.parameters.ModelParameters; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A minimal OpenAI-compatible HTTP endpoint over a loaded {@link LlamaModel}, built only on the JDK's + * {@code com.sun.net.httpserver.HttpServer} (no new runtime dependency). + * + *

Routes: + *

    + *
  • {@code POST /v1/chat/completions} — streaming (Server-Sent Events) and non-streaming chat + * completions, forwarded faithfully (messages/tools verbatim; streamed {@code delta.tool_calls} + * preserved).
  • + *
  • {@code GET /v1/models} — advertises the single configured model.
  • + *
+ * + *

During streaming, the server emits SSE comment heartbeats on a timer so a long prompt prefill on + * CPU does not trip a client's stream-inactivity timeout before the first token. It binds to loopback by + * default and can require a bearer API key. The endpoint is a pass-through: tools are provided and + * executed by the client, not here. + * + *

Typical use: + *

{@code
+ * try (LlamaModel model = new LlamaModel(new ModelParameters().setModel("models/model.gguf"));
+ *      OpenAiCompatServer server = new OpenAiCompatServer(
+ *              model, OpenAiServerConfig.builder().port(8080).modelId("local").build()).start()) {
+ *     Thread.currentThread().join();
+ * }
+ * }
+ */ +public final class OpenAiCompatServer implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiCompatServer.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** The chat-completions route. */ + public static final String PATH_CHAT_COMPLETIONS = "/v1/chat/completions"; + + /** The model-list route. */ + public static final String PATH_MODELS = "/v1/models"; + + private static final int HTTP_OK = 200; + private static final int HTTP_BAD_REQUEST = 400; + private static final int HTTP_UNAUTHORIZED = 401; + private static final int HTTP_NOT_FOUND = 404; + private static final int HTTP_METHOD_NOT_ALLOWED = 405; + private static final int HTTP_SERVER_ERROR = 500; + + private static final String CONTENT_TYPE_JSON = "application/json; charset=utf-8"; + private static final String CONTENT_TYPE_SSE = "text/event-stream; charset=utf-8"; + private static final String BEARER_PREFIX = "Bearer "; + private static final String ERROR_TYPE_REQUEST = "invalid_request_error"; + private static final String ERROR_TYPE_SERVER = "server_error"; + + private final OpenAiServerConfig config; + private final ChatBackend backend; + private final HttpServer http; + private final ExecutorService requestExecutor; + private final ScheduledExecutorService heartbeatExecutor; + + /** + * Create a server backed by a loaded model. + * + * @param model the model to serve completions from (owned by the caller; not closed by the server) + * @param config the server configuration + * @throws IOException if the listening socket cannot be bound + */ + public OpenAiCompatServer(LlamaModel model, OpenAiServerConfig config) throws IOException { + this(new LlamaModelChatBackend(model, new OpenAiRequestMapper()), config); + } + + /** + * Create a server backed by an arbitrary {@link ChatBackend}. Used by tests to drive the full HTTP + * surface without a native library or model. + * + * @param backend the chat engine seam + * @param config the server configuration + * @throws IOException if the listening socket cannot be bound + */ + OpenAiCompatServer(ChatBackend backend, OpenAiServerConfig config) throws IOException { + this.config = config; + this.backend = backend; + this.requestExecutor = Executors.newCachedThreadPool(namedFactory("jllama-openai-http")); + this.heartbeatExecutor = Executors.newScheduledThreadPool(1, namedFactory("jllama-openai-hb")); + this.http = HttpServer.create(new InetSocketAddress(config.getHost(), config.getPort()), 0); + http.createContext("/", this::handleNotFound); + http.createContext(PATH_MODELS, this::handleModels); + http.createContext(PATH_CHAT_COMPLETIONS, this::handleChatCompletions); + http.setExecutor(requestExecutor); + } + + /** + * Start accepting connections. + * + * @return this server, for chaining + */ + public OpenAiCompatServer start() { + http.start(); + LOG.info("OpenAI-compatible server listening on http://{}:{}", config.getHost(), getPort()); + return this; + } + + /** + * The actual bound port (useful when configured with port {@code 0} for an ephemeral port). + * + * @return the port the server is listening on + */ + public int getPort() { + return http.getAddress().getPort(); + } + + /** Stop the server and release its thread pools. The backing model is not closed. */ + @Override + public void close() { + http.stop(0); + requestExecutor.shutdownNow(); + heartbeatExecutor.shutdownNow(); + } + + // ----- handlers ----- + + private void handleChatCompletions(HttpExchange exchange) throws IOException { + try { + if (!"POST".equalsIgnoreCase(exchange.getRequestMethod())) { + sendError(exchange, HTTP_METHOD_NOT_ALLOWED, ERROR_TYPE_REQUEST, "Only POST is supported"); + return; + } + if (!authorized(exchange)) { + sendError(exchange, HTTP_UNAUTHORIZED, ERROR_TYPE_REQUEST, "Missing or invalid API key"); + return; + } + JsonNode request = readBody(exchange); + if (request == null || !request.isObject()) { + sendError(exchange, HTTP_BAD_REQUEST, ERROR_TYPE_REQUEST, "Request body must be a JSON object"); + return; + } + JsonNode messages = request.path("messages"); + if (!messages.isArray() || messages.size() == 0) { + sendError(exchange, HTTP_BAD_REQUEST, ERROR_TYPE_REQUEST, "'messages' must be a non-empty array"); + return; + } + if (request.path("stream").asBoolean(false)) { + streamChat(exchange, request); + } else { + completeChat(exchange, request); + } + } finally { + exchange.close(); + } + } + + private void completeChat(HttpExchange exchange, JsonNode request) throws IOException { + final String body; + try { + body = backend.complete(request); + } catch (IllegalArgumentException e) { + sendError(exchange, HTTP_BAD_REQUEST, ERROR_TYPE_REQUEST, message(e)); + return; + } catch (IOException | RuntimeException e) { + LOG.warn("chat completion failed", e); + sendError(exchange, HTTP_SERVER_ERROR, ERROR_TYPE_SERVER, message(e)); + return; + } + sendJson(exchange, HTTP_OK, body); + } + + private void streamChat(HttpExchange exchange, JsonNode request) throws IOException { + exchange.getResponseHeaders().set("Content-Type", CONTENT_TYPE_SSE); + exchange.getResponseHeaders().set("Cache-Control", "no-cache"); + exchange.sendResponseHeaders(HTTP_OK, 0); + final OutputStream os = exchange.getResponseBody(); + final Object writeLock = new Object(); + final ScheduledFuture heartbeat = heartbeatExecutor.scheduleAtFixedRate( + () -> writeQuietly(os, writeLock, OpenAiSseFormatter.heartbeat()), + config.getHeartbeatMillis(), + config.getHeartbeatMillis(), + TimeUnit.MILLISECONDS); + try { + backend.stream(request, chunkJson -> writeStrict(os, writeLock, OpenAiSseFormatter.sseData(chunkJson))); + writeStrict(os, writeLock, OpenAiSseFormatter.sseDone()); + } catch (IllegalArgumentException e) { + writeQuietly( + os, + writeLock, + OpenAiSseFormatter.sseData(OpenAiSseFormatter.errorJson(message(e), ERROR_TYPE_REQUEST, null))); + } catch (IOException e) { + LOG.debug("client disconnected during stream", e); + } catch (RuntimeException e) { + LOG.warn("streaming chat completion failed", e); + writeQuietly( + os, + writeLock, + OpenAiSseFormatter.sseData(OpenAiSseFormatter.errorJson(message(e), ERROR_TYPE_SERVER, null))); + } finally { + heartbeat.cancel(false); + closeQuietly(os, writeLock); + } + } + + private void handleModels(HttpExchange exchange) throws IOException { + try { + if (!"GET".equalsIgnoreCase(exchange.getRequestMethod())) { + sendError(exchange, HTTP_METHOD_NOT_ALLOWED, ERROR_TYPE_REQUEST, "Only GET is supported"); + return; + } + if (!authorized(exchange)) { + sendError(exchange, HTTP_UNAUTHORIZED, ERROR_TYPE_REQUEST, "Missing or invalid API key"); + return; + } + sendJson(exchange, HTTP_OK, OpenAiSseFormatter.modelsJson(config.getModelId())); + } finally { + exchange.close(); + } + } + + private void handleNotFound(HttpExchange exchange) throws IOException { + try { + sendError(exchange, HTTP_NOT_FOUND, ERROR_TYPE_REQUEST, "Not found: " + exchange.getRequestURI()); + } finally { + exchange.close(); + } + } + + // ----- helpers ----- + + private boolean authorized(HttpExchange exchange) { + if (!config.isAuthenticationEnabled()) { + return true; + } + String expected = config.getApiKey(); + if (expected == null) { + return true; + } + String header = exchange.getRequestHeaders().getFirst("Authorization"); + if (header == null || !header.startsWith(BEARER_PREFIX)) { + return false; + } + return expected.equals(header.substring(BEARER_PREFIX.length())); + } + + private @Nullable JsonNode readBody(HttpExchange exchange) throws IOException { + try (InputStream is = exchange.getRequestBody()) { + return OBJECT_MAPPER.readTree(is); + } catch (JsonProcessingException e) { + LOG.debug("malformed request body", e); + return null; + } + } + + private void sendJson(HttpExchange exchange, int status, String json) throws IOException { + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", CONTENT_TYPE_JSON); + exchange.sendResponseHeaders(status, bytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + } + } + + private void sendError(HttpExchange exchange, int status, String type, String message) throws IOException { + sendJson(exchange, status, OpenAiSseFormatter.errorJson(message, type, null)); + } + + /** Write under the response lock, propagating failures so a streaming generation can be cancelled. */ + private void writeStrict(OutputStream os, Object writeLock, String text) throws IOException { + synchronized (writeLock) { + os.write(text.getBytes(StandardCharsets.UTF_8)); + os.flush(); + } + } + + /** Write under the response lock, swallowing failures (used for heartbeats and best-effort events). */ + private void writeQuietly(OutputStream os, Object writeLock, String text) { + synchronized (writeLock) { + try { + os.write(text.getBytes(StandardCharsets.UTF_8)); + os.flush(); + } catch (IOException e) { + LOG.trace("stream write failed (client likely disconnected)", e); + } + } + } + + private void closeQuietly(OutputStream os, Object writeLock) { + synchronized (writeLock) { + try { + os.close(); + } catch (IOException e) { + LOG.trace("stream close failed", e); + } + } + } + + private static String message(Throwable t) { + String m = t.getMessage(); + return m != null ? m : t.getClass().getSimpleName(); + } + + private static ThreadFactory namedFactory(String prefix) { + AtomicInteger counter = new AtomicInteger(); + return runnable -> { + Thread thread = new Thread(runnable, prefix + "-" + counter.incrementAndGet()); + thread.setDaemon(true); + return thread; + }; + } + + // ----- standalone launcher ----- + + /** + * Command-line launcher: load a GGUF model and serve it over the OpenAI-compatible endpoint. + * + *

Options: {@code --model } (required), {@code --host}, {@code --port}, + * {@code --api-key}, {@code --model-id}, {@code --ctx}, {@code --gpu-layers}, {@code --parallel}. + * + * @param args command-line options + * @throws IOException if the listening socket cannot be bound + */ + public static void main(String[] args) throws IOException { + Map opts = parseArgs(args); + String modelPath = opts.get("model"); + if (modelPath == null) { + System.err.println("Usage: OpenAiCompatServer --model [--host 127.0.0.1] [--port 8080]" + + " [--api-key KEY] [--model-id ID] [--ctx 8192] [--gpu-layers N] [--parallel N]"); + return; + } + + ModelParameters modelParams = new ModelParameters().setModel(modelPath); + OpenAiServerConfig.Builder cfg = OpenAiServerConfig.builder(); + + String host = opts.get("host"); + if (host != null) { + cfg.host(host); + } + String apiKey = opts.get("api-key"); + if (apiKey != null) { + cfg.apiKey(apiKey); + } + String modelId = opts.get("model-id"); + if (modelId != null) { + cfg.modelId(modelId); + } + + // Parse all numeric options in one place so a non-numeric value (e.g. "--port abc") yields a + // clear message instead of an uncaught NumberFormatException stack trace. No System.exit here + // — the noSystemExit architecture rule forbids it; print to stderr and return like the + // missing-"--model" path above. + try { + String ctx = opts.get("ctx"); + if (ctx != null) { + int ctxSize = Integer.parseInt(ctx); + modelParams.setCtxSize(ctxSize); + cfg.maxOutputTokens(Math.min(OpenAiServerConfig.DEFAULT_MAX_OUTPUT_TOKENS, Math.max(1, ctxSize / 2))); + cfg.maxInputTokens(Math.max(1, ctxSize - OpenAiServerConfig.DEFAULT_MAX_OUTPUT_TOKENS)); + } + String gpuLayers = opts.get("gpu-layers"); + if (gpuLayers != null) { + modelParams.setGpuLayers(Integer.parseInt(gpuLayers)); + } + String parallel = opts.get("parallel"); + if (parallel != null) { + modelParams.setParallel(Integer.parseInt(parallel)); + } + String port = opts.get("port"); + if (port != null) { + cfg.port(Integer.parseInt(port)); + } + } catch (NumberFormatException e) { + System.err.println("Invalid numeric option (expected an integer): " + e.getMessage()); + return; + } + + OpenAiServerConfig config = cfg.build(); + + LlamaModel model = new LlamaModel(modelParams); + OpenAiCompatServer server = new OpenAiCompatServer(model, config); + Runtime.getRuntime() + .addShutdownHook(new Thread( + () -> { + server.close(); + model.close(); + }, + "jllama-openai-shutdown")); + server.start(); + printReady(config, server.getPort()); + try { + Thread.currentThread().join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private static Map parseArgs(String[] args) { + Map opts = new HashMap<>(); + for (int i = 0; i < args.length; i++) { + String arg = args[i]; + if (arg.startsWith("--") && i + 1 < args.length) { + opts.put(arg.substring(2), args[i + 1]); + i++; + } + } + return opts; + } + + private static void printReady(OpenAiServerConfig config, int port) { + String url = "http://" + config.getHost() + ":" + port + PATH_CHAT_COMPLETIONS; + System.out.println(); + System.out.println("OpenAI-compatible endpoint ready: " + url); + System.out.println("Add this to VS Code's chatLanguageModels.json (Chat: Manage Language Models):"); + System.out.println("["); + System.out.println(" {"); + System.out.println(" \"name\": \"Local llama.cpp (java-llama.cpp)\","); + System.out.println(" \"vendor\": \"customendpoint\","); + System.out.println( + " \"apiKey\": \"" + (config.isAuthenticationEnabled() ? "" : "local-dummy-key") + "\","); + System.out.println(" \"apiType\": \"chat-completions\","); + System.out.println(" \"models\": ["); + System.out.println(" {"); + System.out.println(" \"id\": \"" + config.getModelId() + "\","); + System.out.println(" \"name\": \"" + config.getModelId() + "\","); + System.out.println(" \"url\": \"" + url + "\","); + System.out.println(" \"toolCalling\": true,"); + System.out.println(" \"vision\": false,"); + System.out.println(" \"maxInputTokens\": " + config.getMaxInputTokens() + ","); + System.out.println(" \"maxOutputTokens\": " + config.getMaxOutputTokens()); + System.out.println(" }"); + System.out.println(" ]"); + System.out.println(" }"); + System.out.println("]"); + } +} diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java b/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java new file mode 100644 index 00000000..aa01487a --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import com.fasterxml.jackson.databind.JsonNode; +import java.util.ArrayList; +import java.util.List; +import net.ladenthin.llama.parameters.InferenceParameters; + +/** + * Pure mapping from an OpenAI {@code /v1/chat/completions} request body to {@link InferenceParameters}. + * + *

The structural fields — {@code messages}, {@code tools}, {@code tool_choice} — are forwarded + * verbatim as raw JSON so the full OpenAI shape (assistant {@code tool_calls}, + * {@code role:"tool"} results with {@code tool_call_id}, and vision {@code image_url} content parts) + * round-trips untouched into the native chat-template parser. Sampling fields are translated to the + * matching {@code InferenceParameters.with*} setter; unknown fields are ignored. + * + *

The {@code stream} flag is intentionally not mapped here — streaming is selected by the caller + * ({@link net.ladenthin.llama.LlamaModel#chatComplete} forces it off, + * {@link net.ladenthin.llama.LlamaModel#streamChatCompletion} forces it on). Stateless and free of JNI + * and model dependencies, so it is unit-testable with JSON literals alone. + */ +final class OpenAiRequestMapper { + + OpenAiRequestMapper() {} + + /** + * Translate an OpenAI chat request into {@link InferenceParameters}. + * + * @param request the parsed OpenAI request object + * @return inference parameters carrying the verbatim messages and mapped sampling options + * @throws IllegalArgumentException if {@code messages} is missing or not a non-empty array + */ + InferenceParameters toInferenceParameters(JsonNode request) { + JsonNode messages = request.path("messages"); + if (!messages.isArray() || messages.size() == 0) { + throw new IllegalArgumentException("'messages' must be a non-empty array"); + } + + InferenceParameters params = InferenceParameters.empty().withMessagesJson(messages.toString()); + + JsonNode tools = request.path("tools"); + if (tools.isArray() && tools.size() > 0) { + params = params.withToolsJson(tools.toString()).withUseChatTemplate(true); + JsonNode toolChoice = request.path("tool_choice"); + if (toolChoice.isTextual()) { + params = params.withToolChoice(toolChoice.asText()); + } + } + + JsonNode temperature = request.path("temperature"); + if (temperature.isNumber()) { + params = params.withTemperature((float) temperature.asDouble()); + } + JsonNode topP = request.path("top_p"); + if (topP.isNumber()) { + params = params.withTopP((float) topP.asDouble()); + } + JsonNode topK = request.path("top_k"); + if (topK.isNumber()) { + params = params.withTopK(topK.asInt()); + } + JsonNode seed = request.path("seed"); + if (seed.isNumber()) { + params = params.withSeed(seed.asInt()); + } + JsonNode presencePenalty = request.path("presence_penalty"); + if (presencePenalty.isNumber()) { + params = params.withPresencePenalty((float) presencePenalty.asDouble()); + } + JsonNode frequencyPenalty = request.path("frequency_penalty"); + if (frequencyPenalty.isNumber()) { + params = params.withFrequencyPenalty((float) frequencyPenalty.asDouble()); + } + + int maxTokens = readMaxTokens(request); + if (maxTokens > 0) { + params = params.withNPredict(maxTokens); + } + + String[] stops = readStops(request); + if (stops.length > 0) { + params = params.withStopStrings(stops); + } + + return params; + } + + /** + * Read the output-token cap, preferring the newer {@code max_completion_tokens} over the legacy + * {@code max_tokens}. + * + * @param request the parsed OpenAI request object + * @return the requested cap, or {@code -1} when neither field is a number + */ + private int readMaxTokens(JsonNode request) { + JsonNode maxCompletion = request.path("max_completion_tokens"); + if (maxCompletion.isNumber()) { + return maxCompletion.asInt(); + } + JsonNode maxTokens = request.path("max_tokens"); + if (maxTokens.isNumber()) { + return maxTokens.asInt(); + } + return -1; + } + + /** + * Read the {@code stop} field, which OpenAI permits as either a single string or an array of + * strings. + * + * @param request the parsed OpenAI request object + * @return the stop strings (possibly empty, never {@code null}) + */ + private String[] readStops(JsonNode request) { + JsonNode stop = request.path("stop"); + List stops = new ArrayList<>(); + if (stop.isTextual()) { + stops.add(stop.asText()); + } else if (stop.isArray()) { + for (JsonNode entry : stop) { + if (entry.isTextual()) { + stops.add(entry.asText()); + } + } + } + return stops.toArray(new String[0]); + } +} diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiServerConfig.java b/src/main/java/net/ladenthin/llama/server/OpenAiServerConfig.java new file mode 100644 index 00000000..098512aa --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/OpenAiServerConfig.java @@ -0,0 +1,257 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import org.jspecify.annotations.Nullable; + +/** + * Immutable configuration for {@link OpenAiCompatServer}. + * + *

Sensible localhost defaults are provided; build instances with {@link #builder()}. The API key is + * deliberately excluded from {@link #toString()} so it is never written to logs. + */ +public final class OpenAiServerConfig { + + /** Default bind address: loopback only, so the endpoint is not exposed off-host. */ + public static final String DEFAULT_HOST = "127.0.0.1"; + + /** Default TCP port. */ + public static final int DEFAULT_PORT = 8080; + + /** Default advertised model id (the {@code id} echoed by {@code GET /v1/models}). */ + public static final String DEFAULT_MODEL_ID = "local-model"; + + /** Default advertised maximum input tokens. */ + public static final int DEFAULT_MAX_INPUT_TOKENS = 8192; + + /** Default advertised maximum output tokens. */ + public static final int DEFAULT_MAX_OUTPUT_TOKENS = 2048; + + /** Default Server-Sent-Events heartbeat interval, in milliseconds. */ + public static final long DEFAULT_HEARTBEAT_MILLIS = 15_000L; + + private final String host; + private final int port; + private final @Nullable String apiKey; + private final String modelId; + private final int maxInputTokens; + private final int maxOutputTokens; + private final long heartbeatMillis; + + private OpenAiServerConfig(Builder builder) { + this.host = builder.host; + this.port = builder.port; + this.apiKey = builder.apiKey; + this.modelId = builder.modelId; + this.maxInputTokens = builder.maxInputTokens; + this.maxOutputTokens = builder.maxOutputTokens; + this.heartbeatMillis = builder.heartbeatMillis; + } + + /** + * Returns a new builder seeded with the localhost defaults. + * + * @return a fresh {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * The bind address (loopback by default). + * + * @return the host the server binds to + */ + public String getHost() { + return host; + } + + /** + * The TCP port. + * + * @return the port the server listens on + */ + public int getPort() { + return port; + } + + /** + * The optional bearer API key. When {@code null}, no {@code Authorization} header is required. + * + * @return the configured API key, or {@code null} when authentication is disabled + */ + public @Nullable String getApiKey() { + return apiKey; + } + + /** + * The advertised model id. + * + * @return the model id reported by {@code GET /v1/models} + */ + public String getModelId() { + return modelId; + } + + /** + * The advertised maximum input-token budget. + * + * @return the advertised max input tokens + */ + public int getMaxInputTokens() { + return maxInputTokens; + } + + /** + * The advertised maximum output-token budget. + * + * @return the advertised max output tokens + */ + public int getMaxOutputTokens() { + return maxOutputTokens; + } + + /** + * The Server-Sent-Events heartbeat interval. + * + * @return the heartbeat interval in milliseconds + */ + public long getHeartbeatMillis() { + return heartbeatMillis; + } + + /** + * Whether bearer-token authentication is enabled (an API key is configured). + * + * @return {@code true} if requests must present a matching bearer token + */ + public boolean isAuthenticationEnabled() { + return apiKey != null && !apiKey.isEmpty(); + } + + /** + * Renders the configuration without exposing the API key. + * + * @return a log-safe description of this configuration + */ + @Override + public String toString() { + return "OpenAiServerConfig{host=" + + host + + ", port=" + + port + + ", authEnabled=" + + isAuthenticationEnabled() + + ", modelId=" + + modelId + + ", maxInputTokens=" + + maxInputTokens + + ", maxOutputTokens=" + + maxOutputTokens + + ", heartbeatMillis=" + + heartbeatMillis + + '}'; + } + + /** Mutable builder for {@link OpenAiServerConfig}; not thread-safe. */ + public static final class Builder { + + private String host = DEFAULT_HOST; + private int port = DEFAULT_PORT; + private @Nullable String apiKey; + private String modelId = DEFAULT_MODEL_ID; + private int maxInputTokens = DEFAULT_MAX_INPUT_TOKENS; + private int maxOutputTokens = DEFAULT_MAX_OUTPUT_TOKENS; + private long heartbeatMillis = DEFAULT_HEARTBEAT_MILLIS; + + private Builder() {} + + /** + * Sets the bind address. + * + * @param host the host to bind (e.g. {@code "127.0.0.1"}) + * @return this builder + */ + public Builder host(String host) { + this.host = host; + return this; + } + + /** + * Sets the TCP port. + * + * @param port the port to listen on + * @return this builder + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * Sets the optional bearer API key. Pass {@code null} (the default) to disable authentication. + * + * @param apiKey the required bearer token, or {@code null} for no authentication + * @return this builder + */ + public Builder apiKey(@Nullable String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets the advertised model id. + * + * @param modelId the model id to advertise + * @return this builder + */ + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + /** + * Sets the advertised maximum input tokens. + * + * @param maxInputTokens the advertised max input tokens + * @return this builder + */ + public Builder maxInputTokens(int maxInputTokens) { + this.maxInputTokens = maxInputTokens; + return this; + } + + /** + * Sets the advertised maximum output tokens. + * + * @param maxOutputTokens the advertised max output tokens + * @return this builder + */ + public Builder maxOutputTokens(int maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + return this; + } + + /** + * Sets the Server-Sent-Events heartbeat interval. + * + * @param heartbeatMillis the heartbeat interval in milliseconds + * @return this builder + */ + public Builder heartbeatMillis(long heartbeatMillis) { + this.heartbeatMillis = heartbeatMillis; + return this; + } + + /** + * Builds the immutable configuration. + * + * @return a new {@link OpenAiServerConfig} + */ + public OpenAiServerConfig build() { + return new OpenAiServerConfig(this); + } + } +} diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiSseFormatter.java b/src/main/java/net/ladenthin/llama/server/OpenAiSseFormatter.java new file mode 100644 index 00000000..9a9d4b5f --- /dev/null +++ b/src/main/java/net/ladenthin/llama/server/OpenAiSseFormatter.java @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.jspecify.annotations.Nullable; + +/** + * Pure formatting helpers for the OpenAI HTTP surface: Server-Sent-Events framing, the {@code [DONE]} + * terminator, heartbeat comments, the {@code GET /v1/models} body, and the OpenAI error envelope. + * + *

Stateless and free of JNI / model dependencies, so each helper is unit-testable with literals. + */ +final class OpenAiSseFormatter { + + /** Shared Jackson mapper; thread-safe and reused. */ + static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private OpenAiSseFormatter() {} + + /** + * Frame a chunk's JSON as one SSE {@code data:} event. + * + * @param json the chunk JSON to send + * @return the SSE event text, terminated by a blank line + */ + static String sseData(String json) { + return "data: " + json + "\n\n"; + } + + /** + * The terminating SSE event that marks the end of an OpenAI stream. + * + * @return {@code "data: [DONE]\n\n"} + */ + static String sseDone() { + return "data: [DONE]\n\n"; + } + + /** + * An SSE comment line used as a keep-alive heartbeat. OpenAI clients ignore comment lines, but the + * bytes reset the client's stream-inactivity timer during long prompt prefill. + * + * @return {@code ": ping\n\n"} + */ + static String heartbeat() { + return ": ping\n\n"; + } + + /** + * Build an OpenAI error envelope: {@code {"error":{"message":…,"type":…,"code":…}}}. + * + * @param message human-readable error message + * @param type OpenAI error type (e.g. {@code "invalid_request_error"}, {@code "server_error"}) + * @param code optional machine-readable code; {@code null} renders as JSON {@code null} + * @return the error envelope serialized as JSON + */ + static String errorJson(String message, String type, @Nullable String code) { + ObjectNode error = OBJECT_MAPPER.createObjectNode(); + error.put("message", message); + error.put("type", type); + if (code != null) { + error.put("code", code); + } else { + error.putNull("code"); + } + ObjectNode root = OBJECT_MAPPER.createObjectNode(); + root.set("error", error); + return root.toString(); + } + + /** + * Build the {@code GET /v1/models} body advertising a single model. + * + * @param modelId the model id to advertise + * @return an OpenAI model-list object serialized as JSON + */ + static String modelsJson(String modelId) { + ObjectNode model = OBJECT_MAPPER.createObjectNode(); + model.put("id", modelId); + model.put("object", "model"); + model.put("owned_by", "llama.cpp"); + ArrayNode data = OBJECT_MAPPER.createArrayNode(); + data.add(model); + ObjectNode root = OBJECT_MAPPER.createObjectNode(); + root.put("object", "list"); + root.set("data", data); + return root.toString(); + } +} diff --git a/src/main/java/net/ladenthin/llama/server/package-info.java b/src/main/java/net/ladenthin/llama/server/package-info.java index 4c642b28..ff1eb971 100644 --- a/src/main/java/net/ladenthin/llama/server/package-info.java +++ b/src/main/java/net/ladenthin/llama/server/package-info.java @@ -3,21 +3,34 @@ // SPDX-License-Identifier: MIT /** - * Optional, self-contained OpenAI-compatible HTTP server built on the in-process - * {@link net.ladenthin.llama.LlamaModel} API. + * Optional OpenAI-compatible HTTP server over a loaded {@link net.ladenthin.llama.LlamaModel}. * - *

{@link net.ladenthin.llama.server.LlamaServer} is the {@code main} entry point (and the - * {@code Main-Class} of the {@code -jar-with-dependencies} assembly). It loads a GGUF model and - * exposes {@code POST /v1/chat/completions}, {@code POST /v1/completions}, - * {@code POST /v1/embeddings} and {@code GET /v1/models} by forwarding the request body to the - * matching {@code LlamaModel.handle*} method, which already returns OpenAI-shaped JSON.

+ *

Interim state — two implementations pending consolidation. This package + * currently contains two independent OpenAI-compatible server implementations that landed + * on separate branches and are awaiting a "best of both" merge (tracked in {@code TODO.md}). Both + * let editors and tools that speak the OpenAI Chat Completions protocol (for example a VS Code + * Copilot "Custom Endpoint") drive a local GGUF model running in-process through the JNI binding, + * and both are faithful pass-throughs that do not implement or execute tools themselves.

* - *

The HTTP layer is NanoHTTPD (a tiny, dependency-free, Java 8 server). The dependency is - * declared {@code } so it is bundled in the fat jar but not inherited by library - * consumers. The routing logic ({@link net.ladenthin.llama.server.OaiRouter}) is decoupled from - * NanoHTTPD so it can be unit-tested without binding a socket or loading a model.

+ *
    + *
  • {@link net.ladenthin.llama.server.LlamaServer} / + * {@link net.ladenthin.llama.server.OaiHttpServer} — a NanoHTTPD-based server. + * {@code LlamaServer} is a {@code main} entry point (and the {@code Main-Class} of the + * {@code -jar-with-dependencies} assembly). It exposes {@code POST /v1/chat/completions}, + * {@code POST /v1/completions}, {@code POST /v1/embeddings} and {@code GET /v1/models} by + * forwarding the request body to the matching {@code LlamaModel.handle*} method, which already + * returns OpenAI-shaped JSON. Routing ({@link net.ladenthin.llama.server.OaiRouter}) is + * decoupled from NanoHTTPD so it is unit-testable without binding a socket or loading a model. + * NanoHTTPD is an {@code } dependency (bundled only in the fat jar).
  • + *
  • {@link net.ladenthin.llama.server.OpenAiCompatServer} — a dependency-free server built only + * on the JDK's {@code com.sun.net.httpserver.HttpServer}. It serves + * {@code POST /v1/chat/completions} (streaming via Server-Sent Events and non-streaming) and + * {@code GET /v1/models}. Streaming comes straight from the native OpenAI chunk formatter (see + * {@link net.ladenthin.llama.LlamaModel#streamChatCompletion(net.ladenthin.llama.parameters.InferenceParameters, java.util.function.Consumer)}), + * so streamed {@code delta.tool_calls} are preserved for agent-mode tool use.
  • + *
* - *

JSpecify {@code @NullMarked} is applied module-wide; everything is non-null unless annotated - * {@code @Nullable}.

+ *

JSpecify {@code @NullMarked} is applied module-wide (see {@code module-info.java}) and applies + * to this package transitively.

*/ package net.ladenthin.llama.server; diff --git a/src/test/cpp/test_jni_helpers.cpp b/src/test/cpp/test_jni_helpers.cpp index 1bc074d6..ccc2c889 100644 --- a/src/test/cpp/test_jni_helpers.cpp +++ b/src/test/cpp/test_jni_helpers.cpp @@ -22,10 +22,6 @@ #include -#include -#include -#include -#include #include "server-context.h" #include "server-queue.h" #include "server-task.h" @@ -33,6 +29,10 @@ #include "server-chat.h" #include "utils.hpp" #include "jni_helpers.hpp" +#include +#include +#include +#include // embedding_to_jfloat_array_impl and tokens_to_jint_array_impl are also tested // in this file (see bottom). @@ -58,21 +58,19 @@ static server_task_result_ptr make_ok(int id_, const std::string &msg = "ok") { // ============================================================ // State captured by stubs — reset in each fixture's SetUp(). -static bool g_throw_called = false; +static bool g_throw_called = false; static std::string g_throw_message; static std::string g_new_string_utf_value; -static jlong g_mock_handle = 0; +static jlong g_mock_handle = 0; static jstring g_new_string_utf_sentinel = reinterpret_cast(0xBEEF); -static jint JNICALL stub_ThrowNew(JNIEnv *, jclass, const char *msg) { - g_throw_called = true; +static jint JNICALL stub_ThrowNew(JNIEnv *, jclass, const char *msg) { + g_throw_called = true; g_throw_message = msg ? msg : ""; return 0; } -static jlong JNICALL stub_GetLongField(JNIEnv *, jobject, jfieldID) { - return g_mock_handle; -} +static jlong JNICALL stub_GetLongField(JNIEnv *, jobject, jfieldID) { return g_mock_handle; } static jstring JNICALL stub_NewStringUTF(JNIEnv *, const char *utf) { g_new_string_utf_value = utf ? utf : ""; return g_new_string_utf_sentinel; @@ -81,25 +79,25 @@ static jstring JNICALL stub_NewStringUTF(JNIEnv *, const char *utf) { // Minimal env: ThrowNew + GetLongField + NewStringUTF. JNIEnv *make_mock_env(JNINativeInterface_ &table, JNIEnv_ &env_obj) { std::memset(&table, 0, sizeof(table)); - table.ThrowNew = stub_ThrowNew; + table.ThrowNew = stub_ThrowNew; table.GetLongField = stub_GetLongField; table.NewStringUTF = stub_NewStringUTF; - env_obj.functions = &table; + env_obj.functions = &table; return &env_obj; } // Base fixture: resets all mock state. struct MockJniFixture : ::testing::Test { JNINativeInterface_ table{}; - JNIEnv_ env_obj{}; - JNIEnv *env = nullptr; - jfieldID dummy_field = reinterpret_cast(0x1); - jclass dummy_class = reinterpret_cast(0x2); + JNIEnv_ env_obj{}; + JNIEnv *env = nullptr; + jfieldID dummy_field = reinterpret_cast(0x1); + jclass dummy_class = reinterpret_cast(0x2); void SetUp() override { - env = make_mock_env(table, env_obj); - g_mock_handle = 0; - g_throw_called = false; + env = make_mock_env(table, env_obj); + g_mock_handle = 0; + g_throw_called = false; g_throw_message.clear(); g_new_string_utf_value.clear(); } @@ -297,8 +295,7 @@ TEST_F(MockJniFixture, RequireJsonField_MissingField_ReturnsFalseAndThrows) { } TEST_F(MockJniFixture, RequireJsonField_EmptyJson_ReturnsFalseAndThrows) { - EXPECT_FALSE(require_json_field_impl( - env, nlohmann::json::object(), "input_suffix", dummy_class)); + EXPECT_FALSE(require_json_field_impl(env, nlohmann::json::object(), "input_suffix", dummy_class)); EXPECT_TRUE(g_throw_called); EXPECT_EQ(g_throw_message, "\"input_suffix\" is required"); } @@ -319,40 +316,38 @@ TEST_F(MockJniFixture, RequireJsonField_NullValue_ReturnsTrueNoThrow) { namespace { -static jint g_array_data[8] = {}; -static jsize g_array_length = 0; -static bool g_release_called = false; -static jint g_release_mode = -1; +static jint g_array_data[8] = {}; +static jsize g_array_length = 0; +static bool g_release_called = false; +static jint g_release_mode = -1; static jsize JNICALL stub_GetArrayLength(JNIEnv *, jarray) { return g_array_length; } -static jint *JNICALL stub_GetIntArrayElements(JNIEnv *, jintArray, jboolean *) { - return g_array_data; -} -static void JNICALL stub_ReleaseIntArrayElements(JNIEnv *, jintArray, jint *, jint mode) { +static jint *JNICALL stub_GetIntArrayElements(JNIEnv *, jintArray, jboolean *) { return g_array_data; } +static void JNICALL stub_ReleaseIntArrayElements(JNIEnv *, jintArray, jint *, jint mode) { g_release_called = true; - g_release_mode = mode; + g_release_mode = mode; } JNIEnv *make_array_env(JNINativeInterface_ &table, JNIEnv_ &env_obj) { std::memset(&table, 0, sizeof(table)); - table.GetArrayLength = stub_GetArrayLength; - table.GetIntArrayElements = stub_GetIntArrayElements; + table.GetArrayLength = stub_GetArrayLength; + table.GetIntArrayElements = stub_GetIntArrayElements; table.ReleaseIntArrayElements = stub_ReleaseIntArrayElements; - env_obj.functions = &table; + env_obj.functions = &table; return &env_obj; } struct ArrayFixture : ::testing::Test { JNINativeInterface_ table{}; - JNIEnv_ env_obj{}; - JNIEnv *env = nullptr; + JNIEnv_ env_obj{}; + JNIEnv *env = nullptr; void SetUp() override { - env = make_array_env(table, env_obj); + env = make_array_env(table, env_obj); g_release_called = false; - g_release_mode = -1; + g_release_mode = -1; std::memset(g_array_data, 0, sizeof(g_array_data)); - g_array_length = 0; + g_array_length = 0; } }; @@ -367,8 +362,10 @@ TEST_F(ArrayFixture, JintArrayToTokens_EmptyArray_ReturnsEmptyVector) { } TEST_F(ArrayFixture, JintArrayToTokens_ThreeElements_CopiedCorrectly) { - g_array_data[0] = 10; g_array_data[1] = 20; g_array_data[2] = 30; - g_array_length = 3; + g_array_data[0] = 10; + g_array_data[1] = 20; + g_array_data[2] = 30; + g_array_length = 3; auto tokens = jint_array_to_tokens_impl(env, nullptr); ASSERT_EQ(tokens.size(), 3u); EXPECT_EQ(tokens[0], 10); @@ -377,7 +374,8 @@ TEST_F(ArrayFixture, JintArrayToTokens_ThreeElements_CopiedCorrectly) { } TEST_F(ArrayFixture, JintArrayToTokens_ReleasesWithAbortFlag) { - g_array_length = 1; g_array_data[0] = 42; + g_array_length = 1; + g_array_data[0] = 42; (void)jint_array_to_tokens_impl(env, nullptr); EXPECT_TRUE(g_release_called); EXPECT_EQ(g_release_mode, JNI_ABORT); @@ -463,8 +461,8 @@ TEST_F(MockJniFixture, ResultsToJstring_EmptyVector_ReturnsEmptyArray) { namespace { -static bool g_float_new_called = false; -static jsize g_float_alloc_size = -1; +static bool g_float_new_called = false; +static jsize g_float_alloc_size = -1; static jsize g_float_copied_size = -1; static jfloatArray JNICALL stub_NewFloatArray(JNIEnv *, jsize n) { @@ -479,10 +477,10 @@ static void JNICALL stub_SetFloatArrayRegion(JNIEnv *, jfloatArray, jsize, jsize struct FloatArrayFixture : MockJniFixture { void SetUp() override { MockJniFixture::SetUp(); - g_float_new_called = false; - g_float_alloc_size = -1; + g_float_new_called = false; + g_float_alloc_size = -1; g_float_copied_size = -1; - table.NewFloatArray = stub_NewFloatArray; + table.NewFloatArray = stub_NewFloatArray; table.SetFloatArrayRegion = stub_SetFloatArrayRegion; } }; @@ -529,8 +527,8 @@ TEST_F(FloatArrayFixture, EmbeddingToJfloatArray_AllocFails_ThrowsOomAndReturnsN namespace { -static bool g_int_new_called = false; -static jsize g_int_alloc_size = -1; +static bool g_int_new_called = false; +static jsize g_int_alloc_size = -1; static jsize g_int_copied_size = -1; static jintArray JNICALL stub_NewIntArray(JNIEnv *, jsize n) { @@ -538,17 +536,15 @@ static jintArray JNICALL stub_NewIntArray(JNIEnv *, jsize n) { g_int_alloc_size = n; return reinterpret_cast(0xF2); } -static void JNICALL stub_SetIntArrayRegion(JNIEnv *, jintArray, jsize, jsize n, const jint *) { - g_int_copied_size = n; -} +static void JNICALL stub_SetIntArrayRegion(JNIEnv *, jintArray, jsize, jsize n, const jint *) { g_int_copied_size = n; } struct IntArrayFixture : MockJniFixture { void SetUp() override { MockJniFixture::SetUp(); - g_int_new_called = false; - g_int_alloc_size = -1; + g_int_new_called = false; + g_int_alloc_size = -1; g_int_copied_size = -1; - table.NewIntArray = stub_NewIntArray; + table.NewIntArray = stub_NewIntArray; table.SetIntArrayRegion = stub_SetIntArrayRegion; } }; diff --git a/src/test/cpp/test_json_helpers.cpp b/src/test/cpp/test_json_helpers.cpp index 6fa579e4..a5afb95e 100644 --- a/src/test/cpp/test_json_helpers.cpp +++ b/src/test/cpp/test_json_helpers.cpp @@ -20,6 +20,7 @@ // is_infill_request // parse_slot_prompt_similarity // parse_positive_int_config +// wrap_stream_chunk #include @@ -45,9 +46,9 @@ namespace { // to_json() → format_error_response() → {"message": msg, ...} matches the // exact JSON key that get_result_error_message reads. static server_task_result_ptr make_error(int id_, const std::string &msg) { - auto r = std::make_unique(); - r->id = id_; - r->err_msg = msg; + auto r = std::make_unique(); + r->id = id_; + r->err_msg = msg; r->err_type = ERROR_TYPE_SERVER; return r; } @@ -69,14 +70,13 @@ struct fake_embedding_result : server_task_result { std::vector vec; int tokens_evaluated; explicit fake_embedding_result(int id_, std::vector v, int tok = 4) - : vec(std::move(v)), tokens_evaluated(tok) { id = id_; } - json to_json() override { - return {{"embedding", vec}, {"tokens_evaluated", tokens_evaluated}}; + : vec(std::move(v)), tokens_evaluated(tok) { + id = id_; } + json to_json() override { return {{"embedding", vec}, {"tokens_evaluated", tokens_evaluated}}; } }; -static server_task_result_ptr make_embedding(int id_, - std::vector v = {0.1f, 0.2f, 0.3f}) { +static server_task_result_ptr make_embedding(int id_, std::vector v = {0.1f, 0.2f, 0.3f}) { return std::make_unique(id_, std::move(v)); } @@ -163,7 +163,8 @@ TEST(ResultsToJson, SingleErrorResult_ReturnsObjectDirectly) { namespace { struct fake_rerank_result : server_task_result { - int index; float score; + int index; + float score; fake_rerank_result(int id_, int idx, float sc) : index(idx), score(sc) { id = id_; } json to_json() override { return {{"index", index}, {"score", score}}; } }; @@ -244,26 +245,20 @@ TEST(RerankResultsToJson, PreservesInputOrder) { // parse_encoding_format // ============================================================ -TEST(ParseEncodingFormat, FieldAbsent_ReturnsFalse) { - EXPECT_FALSE(parse_encoding_format({{"model", "x"}})); -} +TEST(ParseEncodingFormat, FieldAbsent_ReturnsFalse) { EXPECT_FALSE(parse_encoding_format({{"model", "x"}})); } TEST(ParseEncodingFormat, ExplicitFloat_ReturnsFalse) { EXPECT_FALSE(parse_encoding_format({{"encoding_format", "float"}})); } -TEST(ParseEncodingFormat, Base64_ReturnsTrue) { - EXPECT_TRUE(parse_encoding_format({{"encoding_format", "base64"}})); -} +TEST(ParseEncodingFormat, Base64_ReturnsTrue) { EXPECT_TRUE(parse_encoding_format({{"encoding_format", "base64"}})); } TEST(ParseEncodingFormat, UnknownFormat_ThrowsInvalidArgument) { - EXPECT_THROW((void)parse_encoding_format({{"encoding_format", "binary"}}), - std::invalid_argument); + EXPECT_THROW((void)parse_encoding_format({{"encoding_format", "binary"}}), std::invalid_argument); } TEST(ParseEncodingFormat, EmptyString_ThrowsInvalidArgument) { - EXPECT_THROW((void)parse_encoding_format({{"encoding_format", ""}}), - std::invalid_argument); + EXPECT_THROW((void)parse_encoding_format({{"encoding_format", ""}}), std::invalid_argument); } TEST(ParseEncodingFormat, ErrorMessage_MentionsBothValidOptions) { @@ -272,7 +267,7 @@ TEST(ParseEncodingFormat, ErrorMessage_MentionsBothValidOptions) { FAIL() << "Expected std::invalid_argument"; } catch (const std::invalid_argument &e) { const std::string msg(e.what()); - EXPECT_NE(msg.find("float"), std::string::npos); + EXPECT_NE(msg.find("float"), std::string::npos); EXPECT_NE(msg.find("base64"), std::string::npos); } } @@ -297,28 +292,24 @@ TEST(ExtractEmbeddingPrompt, ContentKey_ReturnsValueAndSetsFlag) { TEST(ExtractEmbeddingPrompt, InputTakesPriorityOverContent) { bool flag = false; - json prompt = extract_embedding_prompt( - {{"input", "from input"}, {"content", "from content"}}, flag); + json prompt = extract_embedding_prompt({{"input", "from input"}, {"content", "from content"}}, flag); EXPECT_EQ(prompt, "from input"); EXPECT_FALSE(flag); } TEST(ExtractEmbeddingPrompt, NeitherKey_ThrowsInvalidArgument) { bool flag = false; - EXPECT_THROW((void)extract_embedding_prompt({{"model", "x"}}, flag), - std::invalid_argument); + EXPECT_THROW((void)extract_embedding_prompt({{"model", "x"}}, flag), std::invalid_argument); } TEST(ExtractEmbeddingPrompt, EmptyBody_ThrowsInvalidArgument) { bool flag = false; - EXPECT_THROW((void)extract_embedding_prompt(json::object(), flag), - std::invalid_argument); + EXPECT_THROW((void)extract_embedding_prompt(json::object(), flag), std::invalid_argument); } TEST(ExtractEmbeddingPrompt, ArrayPrompt_ReturnedAsIs) { bool flag = false; - json prompt = extract_embedding_prompt( - {{"input", {"sentence one", "sentence two"}}}, flag); + json prompt = extract_embedding_prompt({{"input", {"sentence one", "sentence two"}}}, flag); ASSERT_TRUE(prompt.is_array()); ASSERT_EQ(prompt.size(), 2u); EXPECT_EQ(prompt[0], "sentence one"); @@ -330,26 +321,17 @@ TEST(ExtractEmbeddingPrompt, ArrayPrompt_ReturnedAsIs) { // is_infill_request // ============================================================ -TEST(IsInfillRequest, HasInputPrefix_ReturnsTrue) { - EXPECT_TRUE(is_infill_request({{"input_prefix", "def f():"}})); -} +TEST(IsInfillRequest, HasInputPrefix_ReturnsTrue) { EXPECT_TRUE(is_infill_request({{"input_prefix", "def f():"}})); } -TEST(IsInfillRequest, HasInputSuffix_ReturnsTrue) { - EXPECT_TRUE(is_infill_request({{"input_suffix", "return 1"}})); -} +TEST(IsInfillRequest, HasInputSuffix_ReturnsTrue) { EXPECT_TRUE(is_infill_request({{"input_suffix", "return 1"}})); } TEST(IsInfillRequest, HasBoth_ReturnsTrue) { - EXPECT_TRUE(is_infill_request( - {{"input_prefix", "def f():"}, {"input_suffix", "return 1"}})); + EXPECT_TRUE(is_infill_request({{"input_prefix", "def f():"}, {"input_suffix", "return 1"}})); } -TEST(IsInfillRequest, HasNeither_ReturnsFalse) { - EXPECT_FALSE(is_infill_request({{"prompt", "hello"}})); -} +TEST(IsInfillRequest, HasNeither_ReturnsFalse) { EXPECT_FALSE(is_infill_request({{"prompt", "hello"}})); } -TEST(IsInfillRequest, EmptyBody_ReturnsFalse) { - EXPECT_FALSE(is_infill_request(json::object())); -} +TEST(IsInfillRequest, EmptyBody_ReturnsFalse) { EXPECT_FALSE(is_infill_request(json::object())); } // ============================================================ // parse_slot_prompt_similarity @@ -378,15 +360,11 @@ TEST(ParseSlotPromptSimilarity, One_ReturnsOne) { } TEST(ParseSlotPromptSimilarity, TooLow_ThrowsInvalidArgument) { - EXPECT_THROW( - (void)parse_slot_prompt_similarity({{"slot_prompt_similarity", -0.1f}}), - std::invalid_argument); + EXPECT_THROW((void)parse_slot_prompt_similarity({{"slot_prompt_similarity", -0.1f}}), std::invalid_argument); } TEST(ParseSlotPromptSimilarity, TooHigh_ThrowsInvalidArgument) { - EXPECT_THROW( - (void)parse_slot_prompt_similarity({{"slot_prompt_similarity", 1.1f}}), - std::invalid_argument); + EXPECT_THROW((void)parse_slot_prompt_similarity({{"slot_prompt_similarity", 1.1f}}), std::invalid_argument); } // ============================================================ @@ -410,13 +388,11 @@ TEST(ParsePositiveIntConfig, ValidLarge_ReturnsValue) { } TEST(ParsePositiveIntConfig, Zero_ThrowsInvalidArgument) { - EXPECT_THROW((void)parse_positive_int_config({{"n_threads", 0}}, "n_threads"), - std::invalid_argument); + EXPECT_THROW((void)parse_positive_int_config({{"n_threads", 0}}, "n_threads"), std::invalid_argument); } TEST(ParsePositiveIntConfig, Negative_ThrowsInvalidArgument) { - EXPECT_THROW((void)parse_positive_int_config({{"n_threads", -4}}, "n_threads"), - std::invalid_argument); + EXPECT_THROW((void)parse_positive_int_config({{"n_threads", -4}}, "n_threads"), std::invalid_argument); } TEST(ParsePositiveIntConfig, ErrorMessage_ContainsKeyName) { @@ -427,3 +403,43 @@ TEST(ParsePositiveIntConfig, ErrorMessage_ContainsKeyName) { EXPECT_NE(std::string(e.what()).find("n_threads_batch"), std::string::npos); } } + +// ============================================================ +// wrap_stream_chunk +// ============================================================ + +TEST(WrapStreamChunk, ObjectPayload_NotStopped) { + json chunk = {{"object", "chat.completion.chunk"}, {"choices", json::array({{{"delta", {{"content", "hi"}}}}})}}; + json out = wrap_stream_chunk(chunk, false); + EXPECT_FALSE(out.at("stop").get()); + ASSERT_TRUE(out.at("data").is_object()); + EXPECT_EQ(out.at("data").at("object").get(), "chat.completion.chunk"); +} + +TEST(WrapStreamChunk, ArrayPayload_Stopped) { + json final_chunks = + json::array({{{"choices", json::array({{{"finish_reason", "stop"}, {"delta", json::object()}}})}}, + {{"usage", {{"completion_tokens", 3}}}}}); + json out = wrap_stream_chunk(final_chunks, true); + EXPECT_TRUE(out.at("stop").get()); + ASSERT_TRUE(out.at("data").is_array()); + EXPECT_EQ(out.at("data").size(), 2u); +} + +TEST(WrapStreamChunk, StopFlagPropagates) { + EXPECT_TRUE(wrap_stream_chunk(json::object(), true).at("stop").get()); + EXPECT_FALSE(wrap_stream_chunk(json::object(), false).at("stop").get()); +} + +TEST(WrapStreamChunk, NullPayload_DataIsNull) { + json out = wrap_stream_chunk(json(), false); + EXPECT_TRUE(out.at("data").is_null()); + EXPECT_FALSE(out.at("stop").get()); +} + +TEST(WrapStreamChunk, ExactlyTwoKeys) { + json out = wrap_stream_chunk(json::object(), false); + EXPECT_EQ(out.size(), 2u); + EXPECT_TRUE(out.contains("data")); + EXPECT_TRUE(out.contains("stop")); +} diff --git a/src/test/cpp/test_log_helpers.cpp b/src/test/cpp/test_log_helpers.cpp index 4fc6cc8e..f7dadfd9 100644 --- a/src/test/cpp/test_log_helpers.cpp +++ b/src/test/cpp/test_log_helpers.cpp @@ -19,21 +19,13 @@ using json = nlohmann::json; // log_level_name // ============================================================ -TEST(LogLevelName, Error) { - EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_ERROR), "ERROR"); -} +TEST(LogLevelName, Error) { EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_ERROR), "ERROR"); } -TEST(LogLevelName, Warn) { - EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_WARN), "WARN"); -} +TEST(LogLevelName, Warn) { EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_WARN), "WARN"); } -TEST(LogLevelName, Info) { - EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_INFO), "INFO"); -} +TEST(LogLevelName, Info) { EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_INFO), "INFO"); } -TEST(LogLevelName, Debug) { - EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_DEBUG), "DEBUG"); -} +TEST(LogLevelName, Debug) { EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_DEBUG), "DEBUG"); } TEST(LogLevelName, NoneFallsBackToInfo) { // GGML_LOG_LEVEL_NONE is not explicitly mapped; the default arm returns INFO @@ -41,9 +33,7 @@ TEST(LogLevelName, NoneFallsBackToInfo) { EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_NONE), "INFO"); } -TEST(LogLevelName, ContFallsBackToInfo) { - EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_CONT), "INFO"); -} +TEST(LogLevelName, ContFallsBackToInfo) { EXPECT_STREQ(log_level_name(GGML_LOG_LEVEL_CONT), "INFO"); } // ============================================================ // format_log_as_json @@ -52,7 +42,7 @@ TEST(LogLevelName, ContFallsBackToInfo) { TEST(FormatLogAsJson, BasicShape) { const std::string out = format_log_as_json(GGML_LOG_LEVEL_INFO, "hello", 1700000000); const json j = json::parse(out); - EXPECT_EQ(j.at("level").get(), "INFO"); + EXPECT_EQ(j.at("level").get(), "INFO"); EXPECT_EQ(j.at("message").get(), "hello"); EXPECT_EQ(j.at("timestamp").get(), 1700000000); } diff --git a/src/test/cpp/test_server.cpp b/src/test/cpp/test_server.cpp index 0db128ac..872f66d3 100644 --- a/src/test/cpp/test_server.cpp +++ b/src/test/cpp/test_server.cpp @@ -38,14 +38,14 @@ namespace { result_timings make_base_timings() { result_timings t; - t.prompt_n = 10; - t.prompt_ms = 200.0; - t.prompt_per_token_ms = 20.0; - t.prompt_per_second = 50.0; - t.predicted_n = 5; - t.predicted_ms = 100.0; + t.prompt_n = 10; + t.prompt_ms = 200.0; + t.prompt_per_token_ms = 20.0; + t.prompt_per_second = 50.0; + t.predicted_n = 5; + t.predicted_ms = 100.0; t.predicted_per_token_ms = 20.0; - t.predicted_per_second = 50.0; + t.predicted_per_second = 50.0; return t; } @@ -76,9 +76,9 @@ TEST(ResultTimings, BaseFieldValues_MatchInput) { result_timings t = make_base_timings(); const json j = t.to_json(); - EXPECT_EQ(j.at("prompt_n").get(), 10); + EXPECT_EQ(j.at("prompt_n").get(), 10); EXPECT_EQ(j.at("predicted_n").get(), 5); - EXPECT_DOUBLE_EQ(j.at("prompt_ms").get(), 200.0); + EXPECT_DOUBLE_EQ(j.at("prompt_ms").get(), 200.0); EXPECT_DOUBLE_EQ(j.at("predicted_per_second").get(), 50.0); } @@ -89,44 +89,40 @@ TEST(ResultTimings, WithoutSpeculative_DraftFieldsAbsent) { const json j = t.to_json(); - EXPECT_FALSE(j.contains("draft_n")) - << "draft_n must be absent when draft_n == 0"; - EXPECT_FALSE(j.contains("draft_n_accepted")) - << "draft_n_accepted must be absent when draft_n == 0"; + EXPECT_FALSE(j.contains("draft_n")) << "draft_n must be absent when draft_n == 0"; + EXPECT_FALSE(j.contains("draft_n_accepted")) << "draft_n_accepted must be absent when draft_n == 0"; } TEST(ResultTimings, WithSpeculative_DraftFieldsPresent) { result_timings t = make_base_timings(); - t.draft_n = 50; + t.draft_n = 50; t.draft_n_accepted = 35; const json j = t.to_json(); - EXPECT_TRUE(j.contains("draft_n")) - << "draft_n must be present when draft_n > 0"; - EXPECT_TRUE(j.contains("draft_n_accepted")) - << "draft_n_accepted must be present when draft_n > 0"; - EXPECT_EQ(j.at("draft_n").get(), 50); + EXPECT_TRUE(j.contains("draft_n")) << "draft_n must be present when draft_n > 0"; + EXPECT_TRUE(j.contains("draft_n_accepted")) << "draft_n_accepted must be present when draft_n > 0"; + EXPECT_EQ(j.at("draft_n").get(), 50); EXPECT_EQ(j.at("draft_n_accepted").get(), 35); } TEST(ResultTimings, DraftNOne_FieldsPresent) { // Edge case: even a single speculative token triggers the fields result_timings t = make_base_timings(); - t.draft_n = 1; + t.draft_n = 1; t.draft_n_accepted = 0; const json j = t.to_json(); EXPECT_TRUE(j.contains("draft_n")); EXPECT_TRUE(j.contains("draft_n_accepted")); - EXPECT_EQ(j.at("draft_n").get(), 1); + EXPECT_EQ(j.at("draft_n").get(), 1); EXPECT_EQ(j.at("draft_n_accepted").get(), 0); } TEST(ResultTimings, DraftFieldsAbsent_WhenExplicitlyZero) { result_timings t = make_base_timings(); - t.draft_n = 0; + t.draft_n = 0; t.draft_n_accepted = 0; const json j = t.to_json(); @@ -170,8 +166,7 @@ TEST(SlotParamsToJson, NewChatSyntaxFields_Present) { task_params p; const json j = p.to_json(); - EXPECT_TRUE(j.contains("chat_format")) - << "chat_format must come from oaicompat_chat_syntax.format"; + EXPECT_TRUE(j.contains("chat_format")) << "chat_format must come from oaicompat_chat_syntax.format"; EXPECT_TRUE(j.contains("reasoning_format")) << "reasoning_format must come from oaicompat_chat_syntax.reasoning_format"; EXPECT_TRUE(j.contains("reasoning_in_content")) @@ -185,8 +180,7 @@ TEST(SlotParamsToJson, OldChatFormatEnum_NotPresent) { task_params p; const json j = p.to_json(); - EXPECT_FALSE(j.contains("oaicompat_chat_format")) - << "Legacy oaicompat_chat_format field must not appear in b8576"; + EXPECT_FALSE(j.contains("oaicompat_chat_format")) << "Legacy oaicompat_chat_format field must not appear in b8576"; } TEST(SlotParamsToJson, GrammarValue_EmptyByDefault) { @@ -274,7 +268,7 @@ TEST(SlotParamsToJson, Lora_PopulatedEntries) { const json j = p.to_json(); // Each entry is {id, scale}; order not guaranteed — build a map to verify ASSERT_EQ(j.at("lora").size(), 2u); - std::map got; + std::map got; for (const auto &entry : j.at("lora")) { got[entry.at("id").get()] = entry.at("scale").get(); } @@ -286,7 +280,7 @@ TEST(SlotParamsToJson, GrammarTriggers_SerialiseViaServerGrammarTrigger) { task_params p; // Add a WORD trigger — must be serialised through server_grammar_trigger common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; trigger.value = "```json"; p.sampling.grammar_triggers.push_back(trigger); @@ -389,9 +383,7 @@ TEST(StrToBytes, AsciiChars) { EXPECT_EQ(bytes[2].get(), static_cast('C')); } -TEST(StrToBytes, EmptyString) { - EXPECT_TRUE(str_to_bytes("").empty()); -} +TEST(StrToBytes, EmptyString) { EXPECT_TRUE(str_to_bytes("").empty()); } TEST(StrToBytes, HighByte) { // Byte 0xFF must survive the conversion unchanged @@ -402,9 +394,13 @@ TEST(StrToBytes, HighByte) { TEST(CompletionTokenOutput, ToJson_PostSampling_UsesProbLabel) { completion_token_output cto; - cto.tok = 1; cto.prob = 0.5f; cto.text_to_send = "hi"; + cto.tok = 1; + cto.prob = 0.5f; + cto.text_to_send = "hi"; completion_token_output::prob_info pi; - pi.tok = 1; pi.txt = "hi"; pi.prob = 0.5f; + pi.tok = 1; + pi.txt = "hi"; + pi.prob = 0.5f; cto.probs.push_back(pi); const json j = cto.to_json(/*post_sampling_probs=*/true); @@ -417,9 +413,13 @@ TEST(CompletionTokenOutput, ToJson_PostSampling_UsesProbLabel) { TEST(CompletionTokenOutput, ToJson_PreSampling_UsesLogprobLabel) { completion_token_output cto; - cto.tok = 2; cto.prob = 0.25f; cto.text_to_send = "x"; + cto.tok = 2; + cto.prob = 0.25f; + cto.text_to_send = "x"; completion_token_output::prob_info pi; - pi.tok = 2; pi.txt = "x"; pi.prob = 0.25f; + pi.tok = 2; + pi.txt = "x"; + pi.prob = 0.25f; cto.probs.push_back(pi); const json j = cto.to_json(/*post_sampling_probs=*/false); @@ -437,7 +437,9 @@ TEST(CompletionTokenOutput, ProbsVectorToJson_Empty_ReturnsEmptyArray) { TEST(CompletionTokenOutput, ProbsVectorToJson_TokenFields) { completion_token_output cto; - cto.tok = 7; cto.prob = 1.0f; cto.text_to_send = "ok"; + cto.tok = 7; + cto.prob = 1.0f; + cto.text_to_send = "ok"; const json j = completion_token_output::probs_vector_to_json({cto}, true); ASSERT_EQ(j.size(), 1u); EXPECT_EQ(j[0].at("id").get(), 7); @@ -453,8 +455,8 @@ TEST(CompletionTokenOutput, ProbsVectorToJson_TokenFields) { TEST(ServerTaskResultRerank, ToJson_AllFieldsPresent) { server_task_result_rerank r; - r.index = 3; - r.score = 0.87f; + r.index = 3; + r.score = 0.87f; r.n_tokens = 42; const json j = r.to_json(); @@ -477,7 +479,7 @@ TEST(ServerTaskResultRerank, ToJson_DefaultScore_IsNegativeLarge) { TEST(ServerTaskResultEmbd, NonOaicompat_ShapeCorrect) { server_task_result_embd e; - e.index = 1; + e.index = 1; e.embedding = {{0.1f, 0.2f}, {0.3f, 0.4f}}; e.n_tokens = 5; e.res_type = TASK_RESPONSE_TYPE_NONE; @@ -491,7 +493,7 @@ TEST(ServerTaskResultEmbd, NonOaicompat_ShapeCorrect) { TEST(ServerTaskResultEmbd, Oaicompat_UsesFirstRow) { server_task_result_embd e; - e.index = 0; + e.index = 0; e.embedding = {{1.0f, 2.0f}, {3.0f, 4.0f}}; e.n_tokens = 8; e.res_type = TASK_RESPONSE_TYPE_OAI_EMBD; @@ -499,7 +501,7 @@ TEST(ServerTaskResultEmbd, Oaicompat_UsesFirstRow) { const json j = e.to_json(); // OAI compat exposes only embedding[0] ASSERT_TRUE(j.at("embedding").is_array()); - EXPECT_EQ(j.at("embedding").size(), 2u); // first row has 2 elements + EXPECT_EQ(j.at("embedding").size(), 2u); // first row has 2 elements EXPECT_FLOAT_EQ(j.at("embedding")[0].get(), 1.0f); EXPECT_EQ(j.at("tokens_evaluated").get(), 8); } @@ -508,8 +510,8 @@ TEST(ServerTaskResultEmbd, NonOaicompat_NTokensAbsent) { // tokens_evaluated must not appear in the non-OAI shape server_task_result_embd e; e.embedding = {{0.5f}}; - e.n_tokens = 3; - e.res_type = TASK_RESPONSE_TYPE_NONE; + e.n_tokens = 3; + e.res_type = TASK_RESPONSE_TYPE_NONE; const json j = e.to_json(); EXPECT_FALSE(j.contains("tokens_evaluated")); } @@ -518,9 +520,9 @@ TEST(ServerTaskResultEmbd, NonOaicompat_SingleRowValues) { // Verify the float values survive the JSON round-trip server_task_result_embd e; e.embedding = {{0.1f, 0.2f, 0.3f}}; - e.res_type = TASK_RESPONSE_TYPE_NONE; + e.res_type = TASK_RESPONSE_TYPE_NONE; const json j = e.to_json(); - ASSERT_EQ(j.at("embedding").size(), 1u); // one row + ASSERT_EQ(j.at("embedding").size(), 1u); // one row ASSERT_EQ(j.at("embedding")[0].size(), 3u); // three elements EXPECT_FLOAT_EQ(j.at("embedding")[0][1].get(), 0.2f); } @@ -529,7 +531,7 @@ TEST(ServerTaskResultEmbd, Dispatcher_NoneRoutes_ToNonOaicompat) { // to_json() dispatches on res_type; NONE → non-oaicompat (full matrix) server_task_result_embd e; e.embedding = {{1.0f, 2.0f}, {3.0f, 4.0f}}; - e.res_type = TASK_RESPONSE_TYPE_NONE; + e.res_type = TASK_RESPONSE_TYPE_NONE; const json j = e.to_json(); EXPECT_EQ(j.at("embedding").size(), 2u); // full 2D matrix EXPECT_FALSE(j.contains("tokens_evaluated")); @@ -541,7 +543,11 @@ TEST(ServerTaskResultEmbd, Dispatcher_NoneRoutes_ToNonOaicompat) { // ============================================================ namespace { -struct ErrorCase { error_type type; int code; std::string type_str; }; +struct ErrorCase { + error_type type; + int code; + std::string type_str; +}; } // namespace TEST(FormatErrorResponse, InvalidRequest_400) { @@ -594,37 +600,101 @@ TEST(FormatErrorResponse, NotSupported_501) { // ============================================================ TEST(ServerTaskTypeHelpers, NeedEmbd_TrueForEmbeddingAndRerank) { - { server_task t; t.type = SERVER_TASK_TYPE_EMBEDDING; EXPECT_TRUE(t.need_embd()); } - { server_task t; t.type = SERVER_TASK_TYPE_RERANK; EXPECT_TRUE(t.need_embd()); } + { + server_task t; + t.type = SERVER_TASK_TYPE_EMBEDDING; + EXPECT_TRUE(t.need_embd()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_RERANK; + EXPECT_TRUE(t.need_embd()); + } } TEST(ServerTaskTypeHelpers, NeedEmbd_FalseForOtherTypes) { - { server_task t; t.type = SERVER_TASK_TYPE_COMPLETION; EXPECT_FALSE(t.need_embd()); } - { server_task t; t.type = SERVER_TASK_TYPE_INFILL; EXPECT_FALSE(t.need_embd()); } - { server_task t; t.type = SERVER_TASK_TYPE_METRICS; EXPECT_FALSE(t.need_embd()); } - { server_task t; t.type = SERVER_TASK_TYPE_CANCEL; EXPECT_FALSE(t.need_embd()); } + { + server_task t; + t.type = SERVER_TASK_TYPE_COMPLETION; + EXPECT_FALSE(t.need_embd()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_INFILL; + EXPECT_FALSE(t.need_embd()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_METRICS; + EXPECT_FALSE(t.need_embd()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_CANCEL; + EXPECT_FALSE(t.need_embd()); + } } TEST(ServerTaskTypeHelpers, NeedLogits_TrueForCompletionAndInfill) { - { server_task t; t.type = SERVER_TASK_TYPE_COMPLETION; EXPECT_TRUE(t.need_logits()); } - { server_task t; t.type = SERVER_TASK_TYPE_INFILL; EXPECT_TRUE(t.need_logits()); } + { + server_task t; + t.type = SERVER_TASK_TYPE_COMPLETION; + EXPECT_TRUE(t.need_logits()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_INFILL; + EXPECT_TRUE(t.need_logits()); + } } TEST(ServerTaskTypeHelpers, NeedLogits_FalseForOtherTypes) { - { server_task t; t.type = SERVER_TASK_TYPE_EMBEDDING; EXPECT_FALSE(t.need_logits()); } - { server_task t; t.type = SERVER_TASK_TYPE_RERANK; EXPECT_FALSE(t.need_logits()); } - { server_task t; t.type = SERVER_TASK_TYPE_METRICS; EXPECT_FALSE(t.need_logits()); } + { + server_task t; + t.type = SERVER_TASK_TYPE_EMBEDDING; + EXPECT_FALSE(t.need_logits()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_RERANK; + EXPECT_FALSE(t.need_logits()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_METRICS; + EXPECT_FALSE(t.need_logits()); + } } TEST(ServerTaskTypeHelpers, NeedSampling_TrueForCompletionAndInfill) { - { server_task t; t.type = SERVER_TASK_TYPE_COMPLETION; EXPECT_TRUE(t.need_sampling()); } - { server_task t; t.type = SERVER_TASK_TYPE_INFILL; EXPECT_TRUE(t.need_sampling()); } + { + server_task t; + t.type = SERVER_TASK_TYPE_COMPLETION; + EXPECT_TRUE(t.need_sampling()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_INFILL; + EXPECT_TRUE(t.need_sampling()); + } } TEST(ServerTaskTypeHelpers, NeedSampling_FalseForNonGenerativeTasks) { - { server_task t; t.type = SERVER_TASK_TYPE_EMBEDDING; EXPECT_FALSE(t.need_sampling()); } - { server_task t; t.type = SERVER_TASK_TYPE_RERANK; EXPECT_FALSE(t.need_sampling()); } - { server_task t; t.type = SERVER_TASK_TYPE_METRICS; EXPECT_FALSE(t.need_sampling()); } + { + server_task t; + t.type = SERVER_TASK_TYPE_EMBEDDING; + EXPECT_FALSE(t.need_sampling()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_RERANK; + EXPECT_FALSE(t.need_sampling()); + } + { + server_task t; + t.type = SERVER_TASK_TYPE_METRICS; + EXPECT_FALSE(t.need_sampling()); + } } // ============================================================ @@ -652,20 +722,20 @@ TEST(ServerTaskNTokens, PopulatedTokens_ReturnsCount) { namespace { server_task_result_metrics make_metrics() { server_task_result_metrics m; - m.n_idle_slots = 2; + m.n_idle_slots = 2; m.n_processing_slots = 1; - m.n_tasks_deferred = 3; - m.t_start = 1234567890LL; + m.n_tasks_deferred = 3; + m.t_start = 1234567890LL; m.n_prompt_tokens_processed_total = 100; - m.t_prompt_processing_total = 50; - m.n_tokens_predicted_total = 200; - m.t_tokens_generation_total = 80; - m.n_prompt_tokens_processed = 10; - m.t_prompt_processing = 5; - m.n_tokens_predicted = 20; - m.t_tokens_generation = 8; - m.n_decode_total = 300; - m.n_busy_slots_total = 4; + m.t_prompt_processing_total = 50; + m.n_tokens_predicted_total = 200; + m.t_tokens_generation_total = 80; + m.n_prompt_tokens_processed = 10; + m.t_prompt_processing = 5; + m.n_tokens_predicted = 20; + m.t_tokens_generation = 8; + m.n_decode_total = 300; + m.n_busy_slots_total = 4; return m; } } // namespace @@ -720,12 +790,12 @@ TEST(ServerTaskResultMetrics, ToJson_SlotDataIsArray) { TEST(ServerTaskResultSlotSaveLoad, SaveMode_CorrectFields) { server_task_result_slot_save_load r; - r.id_slot = 0; + r.id_slot = 0; r.filename = "slot_0.bin"; - r.is_save = true; + r.is_save = true; r.n_tokens = 128; - r.n_bytes = 4096; - r.t_ms = 12.5; + r.n_bytes = 4096; + r.t_ms = 12.5; const json j = r.to_json(); EXPECT_EQ(j.at("filename").get(), "slot_0.bin"); @@ -739,12 +809,12 @@ TEST(ServerTaskResultSlotSaveLoad, SaveMode_CorrectFields) { TEST(ServerTaskResultSlotSaveLoad, LoadMode_CorrectFields) { server_task_result_slot_save_load r; - r.id_slot = 1; + r.id_slot = 1; r.filename = "slot_1.bin"; - r.is_save = false; + r.is_save = false; r.n_tokens = 64; - r.n_bytes = 2048; - r.t_ms = 7.3; + r.n_bytes = 2048; + r.t_ms = 7.3; const json j = r.to_json(); EXPECT_EQ(j.at("n_restored").get(), 64u); @@ -762,7 +832,7 @@ TEST(ServerTaskResultSlotSaveLoad, LoadMode_CorrectFields) { TEST(ServerTaskResultSlotErase, ToJson_NErasedPresent) { server_task_result_slot_erase r; - r.id_slot = 2; + r.id_slot = 2; r.n_erased = 512; const json j = r.to_json(); @@ -787,7 +857,7 @@ TEST(ServerTaskResultApplyLora, ToJson_SuccessTrue) { TEST(ServerTaskResultError, StandardError_HasMessageField) { server_task_result_error e; e.err_type = ERROR_TYPE_SERVER; - e.err_msg = "something went wrong"; + e.err_msg = "something went wrong"; const json j = e.to_json(); EXPECT_EQ(j.at("message").get(), "something went wrong"); } @@ -795,7 +865,7 @@ TEST(ServerTaskResultError, StandardError_HasMessageField) { TEST(ServerTaskResultError, StandardError_HasCodeAndType) { server_task_result_error e; e.err_type = ERROR_TYPE_INVALID_REQUEST; - e.err_msg = "bad param"; + e.err_msg = "bad param"; const json j = e.to_json(); EXPECT_EQ(j.at("code").get(), 400); EXPECT_EQ(j.at("type").get(), "invalid_request_error"); @@ -808,10 +878,10 @@ TEST(ServerTaskResultError, IsError_ReturnsTrue) { TEST(ServerTaskResultError, ExceedContextSize_AddsExtraFields) { server_task_result_error e; - e.err_type = ERROR_TYPE_EXCEED_CONTEXT_SIZE; - e.err_msg = "context full"; + e.err_type = ERROR_TYPE_EXCEED_CONTEXT_SIZE; + e.err_msg = "context full"; e.n_prompt_tokens = 512; - e.n_ctx = 256; + e.n_ctx = 256; const json j = e.to_json(); EXPECT_EQ(j.at("n_prompt_tokens").get(), 512); EXPECT_EQ(j.at("n_ctx").get(), 256); @@ -820,7 +890,7 @@ TEST(ServerTaskResultError, ExceedContextSize_AddsExtraFields) { TEST(ServerTaskResultError, DefaultError_NoExtraContextFields) { server_task_result_error e; e.err_type = ERROR_TYPE_SERVER; - e.err_msg = "fail"; + e.err_msg = "fail"; const json j = e.to_json(); EXPECT_FALSE(j.contains("n_prompt_tokens")); EXPECT_FALSE(j.contains("n_ctx")); @@ -834,13 +904,13 @@ TEST(ServerTaskResultError, DefaultError_NoExtraContextFields) { TEST(ResultPromptProgress, ToJson_AllFourFields) { result_prompt_progress p; - p.total = 100; - p.cache = 40; + p.total = 100; + p.cache = 40; p.processed = 60; - p.time_ms = 1234; + p.time_ms = 1234; const json j = p.to_json(); - EXPECT_EQ(j.at("total").get(), 100); - EXPECT_EQ(j.at("cache").get(), 40); + EXPECT_EQ(j.at("total").get(), 100); + EXPECT_EQ(j.at("cache").get(), 40); EXPECT_EQ(j.at("processed").get(), 60); EXPECT_EQ(j.at("time_ms").get(), 1234); } @@ -848,8 +918,8 @@ TEST(ResultPromptProgress, ToJson_AllFourFields) { TEST(ResultPromptProgress, ToJson_DefaultZeros) { result_prompt_progress p; const json j = p.to_json(); - EXPECT_EQ(j.at("total").get(), 0); - EXPECT_EQ(j.at("cache").get(), 0); + EXPECT_EQ(j.at("total").get(), 0); + EXPECT_EQ(j.at("cache").get(), 0); EXPECT_EQ(j.at("processed").get(), 0); EXPECT_EQ(j.at("time_ms").get(), 0); } @@ -864,10 +934,10 @@ TEST(ResultPromptProgress, ToJson_DefaultZeros) { TEST(ServerTaskResultCmplPartial, NonOaicompat_CoreFields) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; - p.content = "hello"; - p.n_decoded = 3; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_NONE; + p.content = "hello"; + p.n_decoded = 3; p.n_prompt_tokens = 10; const json j = p.to_json_non_oaicompat(); @@ -881,7 +951,7 @@ TEST(ServerTaskResultCmplPartial, NonOaicompat_CoreFields) { TEST(ServerTaskResultCmplPartial, NonOaicompat_TimingsAbsentByDefault) { server_task_result_cmpl_partial p; p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; + p.res_type = TASK_RESPONSE_TYPE_NONE; // timings.prompt_n == 0 by default → timings should be absent const json j = p.to_json_non_oaicompat(); EXPECT_FALSE(j.contains("timings")); @@ -889,8 +959,8 @@ TEST(ServerTaskResultCmplPartial, NonOaicompat_TimingsAbsentByDefault) { TEST(ServerTaskResultCmplPartial, NonOaicompat_TimingsPresentWhenPromptNNonzero) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_NONE; p.timings.prompt_n = 5; const json j = p.to_json_non_oaicompat(); EXPECT_TRUE(j.contains("timings")); @@ -898,19 +968,19 @@ TEST(ServerTaskResultCmplPartial, NonOaicompat_TimingsPresentWhenPromptNNonzero) TEST(ServerTaskResultCmplPartial, NonOaicompat_ProgressAbsentWhenNotProgress) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_NONE; p.is_progress = false; - const json j = p.to_json_non_oaicompat(); + const json j = p.to_json_non_oaicompat(); EXPECT_FALSE(j.contains("prompt_progress")); } TEST(ServerTaskResultCmplPartial, NonOaicompat_ProgressPresentWhenIsProgress) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; - p.is_progress = true; - p.progress.total = 20; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_NONE; + p.is_progress = true; + p.progress.total = 20; p.progress.processed = 10; const json j = p.to_json_non_oaicompat(); ASSERT_TRUE(j.contains("prompt_progress")); @@ -925,8 +995,8 @@ TEST(ServerTaskResultCmplPartial, IsStop_ReturnsFalse) { TEST(ServerTaskResultCmplPartial, NonOaicompat_IdSlotField) { server_task_result_cmpl_partial p; p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; - p.id_slot = 3; + p.res_type = TASK_RESPONSE_TYPE_NONE; + p.id_slot = 3; const json j = p.to_json_non_oaicompat(); EXPECT_EQ(j.at("id_slot").get(), 3); } @@ -934,7 +1004,7 @@ TEST(ServerTaskResultCmplPartial, NonOaicompat_IdSlotField) { TEST(ServerTaskResultCmplPartial, NonOaicompat_CompletionProbabilitiesAbsentWhenProbsEmpty) { server_task_result_cmpl_partial p; p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; + p.res_type = TASK_RESPONSE_TYPE_NONE; // prob_output.probs is empty by default const json j = p.to_json_non_oaicompat(); EXPECT_FALSE(j.contains("completion_probabilities")); @@ -942,11 +1012,13 @@ TEST(ServerTaskResultCmplPartial, NonOaicompat_CompletionProbabilitiesAbsentWhen TEST(ServerTaskResultCmplPartial, NonOaicompat_CompletionProbabilitiesPresentWhenProbsSet) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_NONE; p.post_sampling_probs = true; completion_token_output::prob_info pi; - pi.tok = 5; pi.txt = "hi"; pi.prob = 0.8f; + pi.tok = 5; + pi.txt = "hi"; + pi.prob = 0.8f; p.prob_output.probs.push_back(pi); const json j = p.to_json_non_oaicompat(); ASSERT_TRUE(j.contains("completion_probabilities")); @@ -966,8 +1038,8 @@ TEST(ServerTaskResultCmplFinal, IsStop_ReturnsTrue) { TEST(ServerTaskResultCmplFinal, NonOaicompat_StopAlwaysTrue) { server_task_result_cmpl_final f; - f.content = "done"; - f.n_decoded = 3; + f.content = "done"; + f.n_decoded = 3; f.n_prompt_tokens = 7; const json j = f.to_json_non_oaicompat(); EXPECT_TRUE(j.at("stop").get()); @@ -992,7 +1064,7 @@ TEST(ServerTaskResultCmplFinal, NonOaicompat_StopType_Eos) { TEST(ServerTaskResultCmplFinal, NonOaicompat_StopType_Word) { server_task_result_cmpl_final f; - f.stop = STOP_TYPE_WORD; + f.stop = STOP_TYPE_WORD; f.stopping_word = ""; const json j = f.to_json_non_oaicompat(); EXPECT_EQ(j.at("stop_type").get(), "word"); @@ -1019,10 +1091,12 @@ TEST(ServerTaskResultCmplFinal, NonOaicompat_NoProbsOutput_CompletionProbabiliti TEST(ServerTaskResultCmplFinal, NonOaicompat_WithProbsOutput_CompletionProbabilitiesPresent) { // When probs_output is non-empty and stream==false, the key must appear. server_task_result_cmpl_final f; - f.stream = false; + f.stream = false; f.post_sampling_probs = true; completion_token_output cto; - cto.tok = 42; cto.prob = 0.9f; cto.text_to_send = "hi"; + cto.tok = 42; + cto.prob = 0.9f; + cto.text_to_send = "hi"; f.probs_output.push_back(cto); const json j = f.to_json_non_oaicompat(); ASSERT_TRUE(j.contains("completion_probabilities")); @@ -1032,10 +1106,12 @@ TEST(ServerTaskResultCmplFinal, NonOaicompat_WithProbsOutput_CompletionProbabili TEST(ServerTaskResultCmplFinal, NonOaicompat_StreamModeWithProbs_CompletionProbabilitiesAbsent) { // stream==true suppresses completion_probabilities even if probs_output is set. server_task_result_cmpl_final f; - f.stream = true; + f.stream = true; f.post_sampling_probs = true; completion_token_output cto; - cto.tok = 1; cto.prob = 0.5f; cto.text_to_send = "x"; + cto.tok = 1; + cto.prob = 0.5f; + cto.text_to_send = "x"; f.probs_output.push_back(cto); const json j = f.to_json_non_oaicompat(); EXPECT_FALSE(j.contains("completion_probabilities")); @@ -1049,19 +1125,19 @@ TEST(ServerTaskResultCmplFinal, NonOaicompat_StreamModeWithProbs_CompletionProba TEST(ServerTaskResultCmplFinal, UsageJsonOaicompat_FieldsCorrect) { server_task_result_cmpl_final f; - f.n_decoded = 17; - f.n_prompt_tokens = 8; - f.n_prompt_tokens_cache = 3; + f.n_decoded = 17; + f.n_prompt_tokens = 8; + f.n_prompt_tokens_cache = 3; const json j = f.usage_json_oaicompat(); EXPECT_EQ(j.at("completion_tokens").get(), 17); EXPECT_EQ(j.at("prompt_tokens").get(), 8); - EXPECT_EQ(j.at("total_tokens").get(), 25); // 17 + 8 + EXPECT_EQ(j.at("total_tokens").get(), 25); // 17 + 8 EXPECT_EQ(j.at("prompt_tokens_details").at("cached_tokens").get(), 3); } TEST(ServerTaskResultCmplFinal, UsageJsonOaicompat_TotalTokensIsSumOfBoth) { server_task_result_cmpl_final f; - f.n_decoded = 5; + f.n_decoded = 5; f.n_prompt_tokens = 10; const json j = f.usage_json_oaicompat(); EXPECT_EQ(j.at("total_tokens").get(), f.n_decoded + f.n_prompt_tokens); @@ -1077,10 +1153,10 @@ TEST(ServerTaskResultCmplFinal, UsageJsonOaicompat_TotalTokensIsSumOfBoth) { namespace { server_task_result_cmpl_final make_oai_final(const std::string &content = "hello") { server_task_result_cmpl_final f; - f.content = content; + f.content = content; f.oaicompat_model = "test-model"; f.oaicompat_cmpl_id = "cmpl-test"; - f.n_decoded = 3; + f.n_decoded = 3; f.n_prompt_tokens = 5; return f; } @@ -1184,7 +1260,7 @@ TEST(CmplFinalOaicompatChat, Usage_Present) { TEST(CmplFinalOaicompatChat, WithExplicitOaicompatMsg_MessageContentUsed) { auto f = make_oai_final("ignored"); - f.oaicompat_msg.role = "assistant"; + f.oaicompat_msg.role = "assistant"; f.oaicompat_msg.content = "explicit reply"; const json j = f.to_json_oaicompat_chat(); EXPECT_EQ(j.at("choices")[0].at("message").at("content").get(), "explicit reply"); @@ -1195,8 +1271,8 @@ TEST(CmplFinalOaicompatChat, WithToolCalls_FinishReason_IsToolCalls) { // be "tool_calls" (not "stop"). auto f = make_oai_final(""); common_chat_tool_call tc; - tc.id = "call_1"; - tc.name = "search"; + tc.id = "call_1"; + tc.name = "search"; tc.arguments = R"({"q":"test"})"; f.oaicompat_msg.tool_calls.push_back(tc); f.stop = STOP_TYPE_EOS; @@ -1207,8 +1283,8 @@ TEST(CmplFinalOaicompatChat, WithToolCalls_FinishReason_IsToolCalls) { TEST(CmplFinalOaicompatChat, WithToolCalls_MessageHasToolCallsArray) { auto f = make_oai_final(""); common_chat_tool_call tc; - tc.id = "call_1"; - tc.name = "search"; + tc.id = "call_1"; + tc.name = "search"; tc.arguments = R"({"q":"test"})"; f.oaicompat_msg.tool_calls.push_back(tc); const json j = f.to_json_oaicompat_chat(); @@ -1243,9 +1319,9 @@ TEST(CmplFinalAnthropic, StopReason_EndTurnForEos) { TEST(CmplFinalAnthropic, StopReason_EndTurnForWord) { auto f = make_oai_final(); - f.stop = STOP_TYPE_WORD; + f.stop = STOP_TYPE_WORD; f.stopping_word = ""; - const json j = f.to_json_anthropic(); + const json j = f.to_json_anthropic(); EXPECT_EQ(j.at("stop_reason").get(), "end_turn"); } @@ -1257,32 +1333,35 @@ TEST(CmplFinalAnthropic, StopSequence_NullWhenEmpty) { TEST(CmplFinalAnthropic, StopSequence_ReflectsStoppingWord) { auto f = make_oai_final(); - f.stop = STOP_TYPE_WORD; + f.stop = STOP_TYPE_WORD; f.stopping_word = ""; f.oaicompat_msg.content = "done"; - const json j = f.to_json_anthropic(); + const json j = f.to_json_anthropic(); EXPECT_EQ(j.at("stop_sequence").get(), ""); } TEST(CmplFinalAnthropic, ContentBlock_TextBlockForPlainContent) { auto f = make_oai_final("plain text"); - const json j = f.to_json_anthropic(); + const json j = f.to_json_anthropic(); const json &blks = j.at("content"); ASSERT_FALSE(blks.empty()); // last block is the text block when no reasoning bool found_text = false; for (const auto &b : blks) { - if (b.at("type").get() == "text") { found_text = true; break; } + if (b.at("type").get() == "text") { + found_text = true; + break; + } } EXPECT_TRUE(found_text); } TEST(CmplFinalAnthropic, ContentBlock_ThinkingBlockFirst) { auto f = make_oai_final("answer"); - f.oaicompat_msg.role = "assistant"; - f.oaicompat_msg.content = "answer"; + f.oaicompat_msg.role = "assistant"; + f.oaicompat_msg.content = "answer"; f.oaicompat_msg.reasoning_content = "step by step"; - const json j = f.to_json_anthropic(); + const json j = f.to_json_anthropic(); const json &blks = j.at("content"); ASSERT_GE(blks.size(), 2u); EXPECT_EQ(blks[0].at("type").get(), "thinking"); @@ -1292,18 +1371,18 @@ TEST(CmplFinalAnthropic, ContentBlock_ThinkingBlockFirst) { TEST(CmplFinalAnthropic, ContentBlock_ToolUseBlock) { auto f = make_oai_final(""); common_chat_tool_call tc; - tc.id = "call_1"; - tc.name = "get_weather"; + tc.id = "call_1"; + tc.name = "get_weather"; tc.arguments = R"({"city":"Paris"})"; f.oaicompat_msg.tool_calls.push_back(tc); f.stop = STOP_TYPE_EOS; - const json j = f.to_json_anthropic(); + const json j = f.to_json_anthropic(); EXPECT_EQ(j.at("stop_reason").get(), "tool_use"); bool found_tool = false; for (const auto &b : j.at("content")) { if (b.at("type").get() == "tool_use") { EXPECT_EQ(b.at("name").get(), "get_weather"); - EXPECT_EQ(b.at("id").get(), "call_1"); + EXPECT_EQ(b.at("id").get(), "call_1"); EXPECT_EQ(b.at("input").at("city").get(), "Paris"); found_tool = true; } @@ -1321,10 +1400,10 @@ TEST(CmplFinalAnthropic, ContentBlock_ToolUseBlock) { namespace { server_task_result_cmpl_partial make_partial(const std::string &content = "tok") { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_OAI_CMPL; - p.content = content; - p.oaicompat_model = "test-model"; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_OAI_CMPL; + p.content = content; + p.oaicompat_model = "test-model"; p.oaicompat_cmpl_id = "cmpl-part"; return p; } @@ -1362,7 +1441,9 @@ TEST(CmplPartialOaicompat, LogProbs_NonEmptyProbs_HasContentArray) { // When probs are set, logprobs becomes {"content": [...]} (not null) auto p = make_partial(); completion_token_output::prob_info pi; - pi.tok = 5; pi.txt = "hi"; pi.prob = 0.8f; + pi.tok = 5; + pi.txt = "hi"; + pi.prob = 0.8f; p.prob_output.probs.push_back(pi); const json j = p.to_json_oaicompat(); ASSERT_FALSE(j.at("choices")[0].at("logprobs").is_null()); @@ -1381,19 +1462,19 @@ TEST(CmplPartialOaicompat, LogProbs_NonEmptyProbs_HasContentArray) { TEST(CmplPartialToJsonDispatch, ResTypeNone_RoutesToNonOaicompat) { server_task_result_cmpl_partial p; p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; - p.content = "hello"; - const json j = p.to_json(); // must not assert/abort + p.res_type = TASK_RESPONSE_TYPE_NONE; + p.content = "hello"; + const json j = p.to_json(); // must not assert/abort // non-oaicompat shape has "content" directly EXPECT_EQ(j.at("content").get(), "hello"); } TEST(CmplPartialToJsonDispatch, ResTypeOaiCmpl_RoutesToOaicompat) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_OAI_CMPL; - p.content = "hi"; - p.oaicompat_model = "m"; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_OAI_CMPL; + p.content = "hi"; + p.oaicompat_model = "m"; p.oaicompat_cmpl_id = "c"; const json j = p.to_json(); // oaicompat shape wraps content inside choices @@ -1407,7 +1488,7 @@ TEST(CmplPartialToJsonDispatch, NotUpdated_Asserts) { // so we verify the flag semantics by checking the truthy case passes. // (The death test would require EXPECT_DEATH which needs signal handling.) p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_NONE; + p.res_type = TASK_RESPONSE_TYPE_NONE; EXPECT_NO_THROW(p.to_json()); } @@ -1416,10 +1497,10 @@ TEST(CmplPartialToJsonDispatch, ResTypeAnthropic_RoutesToAnthropicStream) { // returns a json::array (not a json::object like the OAI arms). // With n_decoded==1 the first-token message_start event is emitted. server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_ANTHROPIC; - p.n_decoded = 1; - p.oaicompat_model = "m"; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_ANTHROPIC; + p.n_decoded = 1; + p.oaicompat_model = "m"; p.oaicompat_cmpl_id = "id"; const json j = p.to_json(); EXPECT_TRUE(j.is_array()); @@ -1436,14 +1517,13 @@ TEST(CmplPartialToJsonDispatch, ResTypeAnthropic_RoutesToAnthropicStream) { namespace { // Minimal final result ready for to_json(); no vocab-dependent fields. -server_task_result_cmpl_final make_dispatched_final(task_response_type rt, - bool stream = false) { +server_task_result_cmpl_final make_dispatched_final(task_response_type rt, bool stream = false) { server_task_result_cmpl_final f; - f.is_updated = true; - f.res_type = rt; - f.stream = stream; - f.content = "hi"; - f.oaicompat_model = "m"; + f.is_updated = true; + f.res_type = rt; + f.stream = stream; + f.content = "hi"; + f.oaicompat_model = "m"; f.oaicompat_cmpl_id = "id"; return f; } @@ -1527,7 +1607,7 @@ TEST(CmplFinalVerboseFlag, Oaicompat_TimingsAbsentByDefault) { TEST(CmplFinalVerboseFlag, Oaicompat_TimingsPresentWhenPromptNNonNeg) { auto f = make_oai_final(); - f.timings.prompt_n = 0; // >= 0 triggers inclusion + f.timings.prompt_n = 0; // >= 0 triggers inclusion const json j = f.to_json_oaicompat(); EXPECT_TRUE(j.contains("timings")); } @@ -1547,10 +1627,10 @@ TEST(CmplFinalVerboseFlag, Oaicompat_TimingsPresentWhenPromptNNonNeg) { namespace { server_task_result_cmpl_final make_stream_final(bool include_usage = false) { server_task_result_cmpl_final f; - f.oaicompat_model = "m"; + f.oaicompat_model = "m"; f.oaicompat_cmpl_id = "id"; - f.stop = STOP_TYPE_EOS; - f.include_usage = include_usage; + f.stop = STOP_TYPE_EOS; + f.include_usage = include_usage; // No oaicompat_msg_diffs → just the single terminal chunk return f; } @@ -1575,7 +1655,7 @@ TEST(CmplFinalChatStream, LastChunk_HasNonNullFinishReason) { const json &last_chunk = j.back(); const json &fr = last_chunk.at("choices")[0].at("finish_reason"); EXPECT_FALSE(fr.is_null()); - EXPECT_EQ(fr.get(), "stop"); // STOP_TYPE_EOS → "stop" + EXPECT_EQ(fr.get(), "stop"); // STOP_TYPE_EOS → "stop" } TEST(CmplFinalChatStream, IncludeUsageFalse_NoUsageChunk) { @@ -1655,8 +1735,7 @@ TEST(ParamsFromJsonCmpl, NDiscard_Negative_ClampedToZero) { } TEST(ParamsFromJsonCmpl, EmptyDrySequenceBreakers_Throws) { - EXPECT_THROW(parse_params({{"dry_sequence_breakers", json::array()}}), - std::runtime_error); + EXPECT_THROW(parse_params({{"dry_sequence_breakers", json::array()}}), std::runtime_error); } TEST(ParamsFromJsonCmpl, LoraNotArray_Throws) { @@ -1668,8 +1747,7 @@ TEST(ParamsFromJsonCmpl, RepeatLastN_BelowMinusOne_Throws) { } TEST(ParamsFromJsonCmpl, StreamOptions_IncludeUsage_Parsed) { - const json data = {{"stream", true}, - {"stream_options", {{"include_usage", true}}}}; + const json data = {{"stream", true}, {"stream_options", {{"include_usage", true}}}}; const auto p = parse_params(data); EXPECT_TRUE(p.include_usage); } @@ -1770,19 +1848,14 @@ TEST(ParamsFromJsonCmpl, ReasoningBudgetTokens_ExplicitMinusOne_Disabled) { TEST(ParamsFromJsonCmpl, JsonSchema_SetsOutputFormatGrammarType) { // json_schema without "grammar" → grammar type OUTPUT_FORMAT - const json data = { - {"json_schema", {{"type", "object"}, {"properties", json::object()}}} - }; + const json data = {{"json_schema", {{"type", "object"}, {"properties", json::object()}}}}; const auto p = parse_params(data); EXPECT_EQ(p.sampling.grammar.type, COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT); } TEST(ParamsFromJsonCmpl, GrammarTypeToolCalls_SetsToolCallsType) { // grammar_type="tool_calls" routes to COMMON_GRAMMAR_TYPE_TOOL_CALLS - const json data = { - {"grammar", "root ::= object"}, - {"grammar_type", "tool_calls"} - }; + const json data = {{"grammar", "root ::= object"}, {"grammar_type", "tool_calls"}}; const auto p = parse_params(data); EXPECT_EQ(p.sampling.grammar.type, COMMON_GRAMMAR_TYPE_TOOL_CALLS); } @@ -1803,8 +1876,8 @@ TEST(ParamsFromJsonCmpl, PlainGrammar_NoGrammarType_SetsUserType) { TEST(CmplFinalResponseFields, EmptyList_AllFieldsPresent) { server_task_result_cmpl_final f; - f.content = "hi"; - f.stop = STOP_TYPE_EOS; + f.content = "hi"; + f.stop = STOP_TYPE_EOS; // response_fields is empty by default → full object returned const json j = f.to_json_non_oaicompat(); EXPECT_TRUE(j.contains("content")); @@ -1814,21 +1887,21 @@ TEST(CmplFinalResponseFields, EmptyList_AllFieldsPresent) { TEST(CmplFinalResponseFields, NonEmptyList_OnlyRequestedFieldsPresent) { server_task_result_cmpl_final f; - f.content = "projected"; + f.content = "projected"; f.response_fields = {"content", "tokens_predicted"}; - const json j = f.to_json_non_oaicompat(); + const json j = f.to_json_non_oaicompat(); EXPECT_TRUE(j.contains("content")); EXPECT_TRUE(j.contains("tokens_predicted")); - EXPECT_FALSE(j.contains("stop_type")); // filtered out - EXPECT_FALSE(j.contains("timings")); // filtered out - EXPECT_FALSE(j.contains("prompt")); // filtered out + EXPECT_FALSE(j.contains("stop_type")); // filtered out + EXPECT_FALSE(j.contains("timings")); // filtered out + EXPECT_FALSE(j.contains("prompt")); // filtered out } TEST(CmplFinalResponseFields, ContentValue_PreservedThroughProjection) { server_task_result_cmpl_final f; - f.content = "keep this"; + f.content = "keep this"; f.response_fields = {"content"}; - const json j = f.to_json_non_oaicompat(); + const json j = f.to_json_non_oaicompat(); EXPECT_EQ(j.at("content").get(), "keep this"); } @@ -1844,10 +1917,10 @@ TEST(CmplFinalResponseFields, ContentValue_PreservedThroughProjection) { namespace { server_task_result_cmpl_partial make_chat_partial(int n_decoded = 1) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_OAI_CHAT; - p.n_decoded = n_decoded; - p.oaicompat_model = "m"; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_OAI_CHAT; + p.n_decoded = n_decoded; + p.oaicompat_model = "m"; p.oaicompat_cmpl_id = "id"; return p; } @@ -1907,8 +1980,8 @@ TEST(CmplPartialOaicompatChat, AllChunks_FinishReasonIsNull) { namespace { server_task_result_cmpl_final make_anthropic_stream_final(stop_type st = STOP_TYPE_EOS) { server_task_result_cmpl_final f; - f.stop = st; - f.oaicompat_model = "m"; + f.stop = st; + f.oaicompat_model = "m"; f.oaicompat_cmpl_id = "id"; return f; } @@ -1926,7 +1999,7 @@ TEST(CmplFinalAnthropicStream, LastEvent_IsMessageStop) { } TEST(CmplFinalAnthropicStream, SecondToLast_IsMessageDelta_WithStopReason) { - const json j = make_anthropic_stream_final(STOP_TYPE_EOS).to_json_anthropic_stream(); + const json j = make_anthropic_stream_final(STOP_TYPE_EOS).to_json_anthropic_stream(); // message_delta is always the penultimate event ASSERT_GE(j.size(), 2u); const json &md = j[j.size() - 2]; @@ -1957,8 +2030,10 @@ TEST(CmplFinalAnthropicStream, WithTextDiff_EmitsContentBlockEvents) { bool found_start = false, found_delta = false; for (const auto &ev : j) { const std::string e = ev.at("event").get(); - if (e == "content_block_start") found_start = true; - if (e == "content_block_delta") found_delta = true; + if (e == "content_block_start") + found_start = true; + if (e == "content_block_delta") + found_delta = true; } EXPECT_TRUE(found_start); EXPECT_TRUE(found_delta); @@ -1997,11 +2072,11 @@ TEST(CmplFinalAnthropicStream, WithThinkingDiff_EmitsThinkingBlockEvents) { namespace { server_task_result_cmpl_partial make_anthropic_partial(int n_decoded = 1) { server_task_result_cmpl_partial p; - p.is_updated = true; - p.res_type = TASK_RESPONSE_TYPE_ANTHROPIC; - p.n_decoded = n_decoded; - p.n_prompt_tokens = 10; - p.oaicompat_model = "test-model"; + p.is_updated = true; + p.res_type = TASK_RESPONSE_TYPE_ANTHROPIC; + p.n_decoded = n_decoded; + p.n_prompt_tokens = 10; + p.oaicompat_model = "test-model"; p.oaicompat_cmpl_id = "msg-id"; return p; } @@ -2014,7 +2089,7 @@ TEST(CmplPartialAnthropicStream, FirstToken_EmitsMessageStart) { } TEST(CmplPartialAnthropicStream, FirstToken_MessageStart_HasIdModelRole) { - const json j = make_anthropic_partial(1).to_json_anthropic(); + const json j = make_anthropic_partial(1).to_json_anthropic(); const json &msg = j.front().at("data").at("message"); EXPECT_EQ(msg.at("id").get(), "msg-id"); EXPECT_EQ(msg.at("model").get(), "test-model"); @@ -2025,11 +2100,11 @@ TEST(CmplPartialAnthropicStream, FirstToken_MessageStart_HasIdModelRole) { TEST(CmplPartialAnthropicStream, FirstToken_MessageStart_HasUsageCounts) { auto p = make_anthropic_partial(1); - p.n_prompt_tokens = 12; + p.n_prompt_tokens = 12; p.n_prompt_tokens_cache = 4; - const json j = p.to_json_anthropic(); + const json j = p.to_json_anthropic(); const json &usage = j.front().at("data").at("message").at("usage"); - EXPECT_EQ(usage.at("input_tokens").get(), 8); // 12 - 4 + EXPECT_EQ(usage.at("input_tokens").get(), 8); // 12 - 4 EXPECT_EQ(usage.at("cache_read_input_tokens").get(), 4); EXPECT_EQ(usage.at("output_tokens").get(), 0); } @@ -2107,9 +2182,9 @@ TEST(CmplPartialAnthropicStream, WithReasoningFlag_TextBlockIndex_IsOne) { TEST(CmplPartialAnthropicStream, WithToolCallDiff_EmitsToolUseBlockStart) { auto p = make_anthropic_partial(/*n_decoded=*/2); common_chat_msg_diff diff; - diff.tool_call_index = 0; + diff.tool_call_index = 0; diff.tool_call_delta.name = "get_weather"; - diff.tool_call_delta.id = "call_abc"; + diff.tool_call_delta.id = "call_abc"; p.oaicompat_msg_diffs.push_back(diff); const json j = p.to_json_anthropic(); bool found_tool_start = false; @@ -2118,11 +2193,10 @@ TEST(CmplPartialAnthropicStream, WithToolCallDiff_EmitsToolUseBlockStart) { const json &cb = ev.at("data").at("content_block"); if (cb.at("type").get() == "tool_use") { EXPECT_EQ(cb.at("name").get(), "get_weather"); - EXPECT_EQ(cb.at("id").get(), "call_abc"); + EXPECT_EQ(cb.at("id").get(), "call_abc"); found_tool_start = true; } } } EXPECT_TRUE(found_tool_start); } - diff --git a/src/test/cpp/test_utils.cpp b/src/test/cpp/test_utils.cpp index 6a72d391..a2382387 100644 --- a/src/test/cpp/test_utils.cpp +++ b/src/test/cpp/test_utils.cpp @@ -39,17 +39,17 @@ TEST(ServerGrammarTrigger, DefaultConstruct) { TEST(ServerGrammarTrigger, ConstructFromTrigger) { common_grammar_trigger t; - t.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; + t.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; t.value = "tool_call"; server_grammar_trigger sgt(t); - EXPECT_EQ(sgt.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_WORD); + EXPECT_EQ(sgt.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_WORD); EXPECT_EQ(sgt.value.value, "tool_call"); } TEST(ServerGrammarTrigger, WordType_RoundTrip) { common_grammar_trigger t; - t.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; + t.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; t.value = "```json"; json j = server_grammar_trigger(t).to_json(); @@ -58,49 +58,49 @@ TEST(ServerGrammarTrigger, WordType_RoundTrip) { EXPECT_TRUE(j.contains("value")); EXPECT_FALSE(j.contains("token")); // "token" field is TOKEN-type only - EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_WORD)); + EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_WORD)); EXPECT_EQ(j.at("value").get(), "```json"); server_grammar_trigger restored(j); - EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_WORD); + EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_WORD); EXPECT_EQ(restored.value.value, "```json"); } TEST(ServerGrammarTrigger, PatternType_RoundTrip) { common_grammar_trigger t; - t.type = COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN; + t.type = COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN; t.value = "^\\{"; json j = server_grammar_trigger(t).to_json(); - EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN)); + EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN)); EXPECT_EQ(j.at("value").get(), "^\\{"); EXPECT_FALSE(j.contains("token")); server_grammar_trigger restored(j); - EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN); + EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN); EXPECT_EQ(restored.value.value, "^\\{"); } TEST(ServerGrammarTrigger, PatternFullType_RoundTrip) { common_grammar_trigger t; - t.type = COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL; + t.type = COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL; t.value = ".*.*"; json j = server_grammar_trigger(t).to_json(); - EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL)); + EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL)); EXPECT_EQ(j.at("value").get(), ".*.*"); EXPECT_FALSE(j.contains("token")); server_grammar_trigger restored(j); - EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL); + EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL); EXPECT_EQ(restored.value.value, ".*.*"); } TEST(ServerGrammarTrigger, TokenType_IncludesTokenField) { common_grammar_trigger t; - t.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + t.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; t.value = ""; t.token = 12345; @@ -108,17 +108,17 @@ TEST(ServerGrammarTrigger, TokenType_IncludesTokenField) { EXPECT_TRUE(j.contains("token")); // only TOKEN type serialises the token id EXPECT_EQ(j.at("token").get(), 12345); - EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN)); + EXPECT_EQ(j.at("type").get(), static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN)); server_grammar_trigger restored(j); - EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN); + EXPECT_EQ(restored.value.type, COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN); EXPECT_EQ(restored.value.token, 12345); EXPECT_EQ(restored.value.value, ""); } TEST(ServerGrammarTrigger, TypeField_IsIntInJson) { common_grammar_trigger t; - t.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; + t.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; t.value = "x"; json j = server_grammar_trigger(t).to_json(); @@ -131,9 +131,7 @@ TEST(ServerGrammarTrigger, TypeField_IsIntInJson) { // existed; tool call IDs were not generated separately). // ============================================================ -TEST(GenToolCallId, NonEmpty) { - EXPECT_FALSE(gen_tool_call_id().empty()); -} +TEST(GenToolCallId, NonEmpty) { EXPECT_FALSE(gen_tool_call_id().empty()); } TEST(GenToolCallId, Length_Is32) { // random_string() always produces exactly 32 characters @@ -143,8 +141,7 @@ TEST(GenToolCallId, Length_Is32) { TEST(GenToolCallId, ContainsOnlyAlphanumeric) { const std::string id = gen_tool_call_id(); for (char c : id) { - EXPECT_TRUE(std::isalnum(static_cast(c))) - << "Non-alphanumeric character: '" << c << "'"; + EXPECT_TRUE(std::isalnum(static_cast(c))) << "Non-alphanumeric character: '" << c << "'"; } } @@ -176,12 +173,13 @@ json make_rank(int index, double score, int tokens_evaluated = 10) { TEST(FormatResponseRerank, JinaFormat_WrapperStructure) { json request = {{"model", "my-reranker"}}; - json ranks = json::array({make_rank(0, 0.5), make_rank(1, 0.9)}); + json ranks = json::array({make_rank(0, 0.5), make_rank(1, 0.9)}); std::vector texts = {"doc0", "doc1"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, /*is_tei=*/false, texts, /*top_n=*/2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, /*is_tei=*/false, texts, /*top_n=*/2); - EXPECT_EQ(res.at("model").get(), "my-reranker"); + EXPECT_EQ(res.at("model").get(), "my-reranker"); EXPECT_EQ(res.at("object").get(), "list"); EXPECT_TRUE(res.contains("usage")); EXPECT_TRUE(res.contains("results")); @@ -190,10 +188,11 @@ TEST(FormatResponseRerank, JinaFormat_WrapperStructure) { TEST(FormatResponseRerank, JinaFormat_UsesRelevanceScoreLabel) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.7)}); + json ranks = json::array({make_rank(0, 0.7)}); std::vector texts = {"doc"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, 1); EXPECT_TRUE(res.at("results")[0].contains("relevance_score")); EXPECT_FALSE(res.at("results")[0].contains("score")); @@ -205,7 +204,8 @@ TEST(FormatResponseRerank, JinaFormat_SortedDescendingByScore) { json ranks = json::array({make_rank(0, 0.3), make_rank(1, 0.9), make_rank(2, 0.1)}); std::vector texts = {"a", "b", "c"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 3); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, 3); auto &results = res.at("results"); EXPECT_EQ(results[0].at("index").get(), 1); // highest: 0.9 @@ -215,10 +215,11 @@ TEST(FormatResponseRerank, JinaFormat_SortedDescendingByScore) { TEST(FormatResponseRerank, TopN_LimitsResultCount) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.5), make_rank(1, 0.9), make_rank(2, 0.1)}); + json ranks = json::array({make_rank(0, 0.5), make_rank(1, 0.9), make_rank(2, 0.1)}); std::vector texts = {"a", "b", "c"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, /*top_n=*/1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, /*top_n=*/1); EXPECT_EQ(res.at("results").size(), 1u); // The single returned result must be the highest-scoring one @@ -227,11 +228,11 @@ TEST(FormatResponseRerank, TopN_LimitsResultCount) { TEST(FormatResponseRerank, TopN_Two_KeepsTopTwo) { json request = json::object(); - json ranks = json::array({ - make_rank(0, 0.1), make_rank(1, 0.9), make_rank(2, 0.5), make_rank(3, 0.7)}); + json ranks = json::array({make_rank(0, 0.1), make_rank(1, 0.9), make_rank(2, 0.5), make_rank(3, 0.7)}); std::vector texts = {"a", "b", "c", "d"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, 2); EXPECT_EQ(res.at("results").size(), 2u); EXPECT_EQ(res.at("results")[0].at("index").get(), 1); // 0.9 @@ -240,10 +241,11 @@ TEST(FormatResponseRerank, TopN_Two_KeepsTopTwo) { TEST(FormatResponseRerank, TopN_LargerThanCount_ReturnsAll) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.8), make_rank(1, 0.2)}); + json ranks = json::array({make_rank(0, 0.8), make_rank(1, 0.2)}); std::vector texts = {"x", "y"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, /*top_n=*/100); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, /*top_n=*/100); EXPECT_EQ(res.at("results").size(), 2u); } @@ -251,10 +253,11 @@ TEST(FormatResponseRerank, TopN_LargerThanCount_ReturnsAll) { TEST(FormatResponseRerank, TopN_Zero_ReturnsEmptyResults) { // top_n=0 must truncate to zero elements, not crash or return all json request = json::object(); - json ranks = json::array({make_rank(0, 0.9), make_rank(1, 0.5)}); + json ranks = json::array({make_rank(0, 0.9), make_rank(1, 0.5)}); std::vector texts = {"a", "b"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, /*top_n=*/0); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, /*top_n=*/0); ASSERT_TRUE(res.at("results").is_array()); EXPECT_TRUE(res.at("results").empty()); @@ -262,21 +265,23 @@ TEST(FormatResponseRerank, TopN_Zero_ReturnsEmptyResults) { TEST(FormatResponseRerank, TokenCounting_Accumulated) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.5, 15), make_rank(1, 0.9, 25)}); + json ranks = json::array({make_rank(0, 0.5, 15), make_rank(1, 0.9, 25)}); std::vector texts = {"a", "b"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, false, texts, 2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, false, texts, 2); EXPECT_EQ(res.at("usage").at("prompt_tokens").get(), 40); // 15 + 25 - EXPECT_EQ(res.at("usage").at("total_tokens").get(), 40); + EXPECT_EQ(res.at("usage").at("total_tokens").get(), 40); } TEST(FormatResponseRerank, TeiFormat_ReturnsArrayDirectly) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.8), make_rank(1, 0.3)}); + json ranks = json::array({make_rank(0, 0.8), make_rank(1, 0.3)}); std::vector texts = {"x", "y"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, /*is_tei=*/true, texts, 2); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, /*is_tei=*/true, texts, 2); EXPECT_TRUE(res.is_array()); // no outer wrapper object EXPECT_EQ(res.size(), 2u); @@ -284,10 +289,11 @@ TEST(FormatResponseRerank, TeiFormat_ReturnsArrayDirectly) { TEST(FormatResponseRerank, TeiFormat_UsesScoreLabel) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.8)}); + json ranks = json::array({make_rank(0, 0.8)}); std::vector texts = {"doc"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, true, texts, 1); ASSERT_TRUE(res.is_array()); EXPECT_TRUE(res[0].contains("score")); @@ -296,10 +302,11 @@ TEST(FormatResponseRerank, TeiFormat_UsesScoreLabel) { TEST(FormatResponseRerank, TeiFormat_ReturnText_IncludesDocumentText) { json request = {{"return_text", true}}; - json ranks = json::array({make_rank(0, 0.9)}); + json ranks = json::array({make_rank(0, 0.9)}); std::vector texts = {"my document content"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, true, texts, 1); ASSERT_TRUE(res.is_array()); EXPECT_TRUE(res[0].contains("text")); @@ -308,10 +315,11 @@ TEST(FormatResponseRerank, TeiFormat_ReturnText_IncludesDocumentText) { TEST(FormatResponseRerank, TeiFormat_NoReturnText_NoTextField) { json request = {{"return_text", false}}; - json ranks = json::array({make_rank(0, 0.9)}); + json ranks = json::array({make_rank(0, 0.9)}); std::vector texts = {"doc"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 1); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, true, texts, 1); ASSERT_TRUE(res.is_array()); EXPECT_FALSE(res[0].contains("text")); @@ -319,10 +327,11 @@ TEST(FormatResponseRerank, TeiFormat_NoReturnText_NoTextField) { TEST(FormatResponseRerank, TeiFormat_SortedDescendingByScore) { json request = json::object(); - json ranks = json::array({make_rank(0, 0.1), make_rank(1, 0.9), make_rank(2, 0.5)}); + json ranks = json::array({make_rank(0, 0.1), make_rank(1, 0.9), make_rank(2, 0.5)}); std::vector texts = {"a", "b", "c"}; - json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), ranks, true, texts, 3); + json res = format_response_rerank(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), + ranks, true, texts, 3); ASSERT_TRUE(res.is_array()); EXPECT_EQ(res[0].at("index").get(), 1); // 0.9 @@ -541,18 +550,14 @@ TEST(ServerTokens, MoveAssign_TransfersOwnership) { TEST(ServerTokens, CopyIsDeleted) { // Compile-time assertion: copying must be disabled to prevent // accidental shallow copies of the chunk map. - static_assert(!std::is_copy_constructible::value, - "server_tokens must not be copy-constructible"); - static_assert(!std::is_copy_assignable::value, - "server_tokens must not be copy-assignable"); + static_assert(!std::is_copy_constructible::value, "server_tokens must not be copy-constructible"); + static_assert(!std::is_copy_assignable::value, "server_tokens must not be copy-assignable"); SUCCEED(); } TEST(ServerTokens, MoveIsAllowed) { - static_assert(std::is_move_constructible::value, - "server_tokens must be move-constructible"); - static_assert(std::is_move_assignable::value, - "server_tokens must be move-assignable"); + static_assert(std::is_move_constructible::value, "server_tokens must be move-constructible"); + static_assert(std::is_move_assignable::value, "server_tokens must be move-assignable"); SUCCEED(); } @@ -639,17 +644,11 @@ TEST(JsonValue, BoolValue) { // json_is_array_of_numbers / json_is_array_of_mixed // ============================================================ -TEST(JsonArrayChecks, ArrayOfIntegers_IsNumbers) { - EXPECT_TRUE(json_is_array_of_numbers(json{1, 2, 3})); -} +TEST(JsonArrayChecks, ArrayOfIntegers_IsNumbers) { EXPECT_TRUE(json_is_array_of_numbers(json{1, 2, 3})); } -TEST(JsonArrayChecks, EmptyArray_IsNumbers) { - EXPECT_TRUE(json_is_array_of_numbers(json::array())); -} +TEST(JsonArrayChecks, EmptyArray_IsNumbers) { EXPECT_TRUE(json_is_array_of_numbers(json::array())); } -TEST(JsonArrayChecks, ArrayWithString_NotNumbers) { - EXPECT_FALSE(json_is_array_of_numbers(json{1, "hello", 3})); -} +TEST(JsonArrayChecks, ArrayWithString_NotNumbers) { EXPECT_FALSE(json_is_array_of_numbers(json{1, "hello", 3})); } TEST(JsonArrayChecks, NonArray_NotNumbers) { EXPECT_FALSE(json_is_array_of_numbers(json("just a string"))); @@ -660,17 +659,11 @@ TEST(JsonArrayChecks, MixedNumbersAndStrings_IsMixed) { EXPECT_TRUE(json_is_array_of_mixed_numbers_strings(json{1, "hello", 3})); } -TEST(JsonArrayChecks, OnlyNumbers_NotMixed) { - EXPECT_FALSE(json_is_array_of_mixed_numbers_strings(json{1, 2, 3})); -} +TEST(JsonArrayChecks, OnlyNumbers_NotMixed) { EXPECT_FALSE(json_is_array_of_mixed_numbers_strings(json{1, 2, 3})); } -TEST(JsonArrayChecks, OnlyStrings_NotMixed) { - EXPECT_FALSE(json_is_array_of_mixed_numbers_strings(json{"a", "b"})); -} +TEST(JsonArrayChecks, OnlyStrings_NotMixed) { EXPECT_FALSE(json_is_array_of_mixed_numbers_strings(json{"a", "b"})); } -TEST(JsonArrayChecks, EmptyArray_NotMixed) { - EXPECT_FALSE(json_is_array_of_mixed_numbers_strings(json::array())); -} +TEST(JsonArrayChecks, EmptyArray_NotMixed) { EXPECT_FALSE(json_is_array_of_mixed_numbers_strings(json::array())); } // json_is_array_and_contains_numbers // Returns true when the input is an array that has at least one integer @@ -689,9 +682,7 @@ TEST(JsonArrayChecks, EmptyArray_NotContainsNumbers) { EXPECT_FALSE(json_is_array_and_contains_numbers(json::array())); } -TEST(JsonArrayChecks, NonArray_NotContainsNumbers) { - EXPECT_FALSE(json_is_array_and_contains_numbers(json(42))); -} +TEST(JsonArrayChecks, NonArray_NotContainsNumbers) { EXPECT_FALSE(json_is_array_and_contains_numbers(json(42))); } // ============================================================ // validate_utf8 — pure logic, no llama.cpp deps @@ -702,9 +693,7 @@ TEST(ValidateUtf8, AsciiOnly_ReturnsFullLength) { EXPECT_EQ(validate_utf8(s), s.size()); } -TEST(ValidateUtf8, EmptyString_ReturnsZero) { - EXPECT_EQ(validate_utf8(""), 0u); -} +TEST(ValidateUtf8, EmptyString_ReturnsZero) { EXPECT_EQ(validate_utf8(""), 0u); } TEST(ValidateUtf8, ValidTwoByteSequence_FullLength) { // "é" = 0xC3 0xA9 @@ -746,13 +735,9 @@ TEST(ValidateUtf8, MixedAsciiAndMultiByte_ReturnsFullLength) { // is_valid_utf8 — pure logic, no llama.cpp deps // ============================================================ -TEST(IsValidUtf8, PlainAscii_Valid) { - EXPECT_TRUE(is_valid_utf8("Hello, World!")); -} +TEST(IsValidUtf8, PlainAscii_Valid) { EXPECT_TRUE(is_valid_utf8("Hello, World!")); } -TEST(IsValidUtf8, EmptyString_Valid) { - EXPECT_TRUE(is_valid_utf8("")); -} +TEST(IsValidUtf8, EmptyString_Valid) { EXPECT_TRUE(is_valid_utf8("")); } TEST(IsValidUtf8, TwoByteChar_Valid) { EXPECT_TRUE(is_valid_utf8("\xC3\xA9")); // é @@ -767,9 +752,7 @@ TEST(IsValidUtf8, FourByteChar_Valid) { EXPECT_TRUE(is_valid_utf8("\xF0\x9F\x98\x80")); } -TEST(IsValidUtf8, InvalidLeadByte_Invalid) { - EXPECT_FALSE(is_valid_utf8("\xFF\xFF")); -} +TEST(IsValidUtf8, InvalidLeadByte_Invalid) { EXPECT_FALSE(is_valid_utf8("\xFF\xFF")); } TEST(IsValidUtf8, TruncatedTwoByte_Invalid) { EXPECT_FALSE(is_valid_utf8("\xC3")); // missing continuation byte @@ -911,7 +894,8 @@ json make_embedding_elem(const std::vector &vec, int tokens = 4) { TEST(FormatEmbeddingsResponse, SingleEmbedding_Fields) { const json request = {{"model", "test-model"}}; const json embeddings = json::array({make_embedding_elem({0.1f, 0.2f, 0.3f})}); - const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); + const json res = format_embeddings_response_oaicompat( + request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); EXPECT_EQ(res.at("object").get(), "list"); EXPECT_EQ(res.at("model").get(), "test-model"); EXPECT_EQ(res.at("data").size(), 1u); @@ -922,7 +906,8 @@ TEST(FormatEmbeddingsResponse, SingleEmbedding_Fields) { TEST(FormatEmbeddingsResponse, TokensAccumulated) { const json request = {}; const json embeddings = json::array({make_embedding_elem({1.0f}, 3), make_embedding_elem({2.0f}, 7)}); - const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); + const json res = format_embeddings_response_oaicompat( + request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); EXPECT_EQ(res.at("usage").at("prompt_tokens").get(), 10); EXPECT_EQ(res.at("usage").at("total_tokens").get(), 10); } @@ -934,7 +919,8 @@ TEST(FormatEmbeddingsResponse, MultipleEmbeddings_IndicesIncrement) { make_embedding_elem({0.2f}), make_embedding_elem({0.3f}), }); - const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); + const json res = format_embeddings_response_oaicompat( + request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings); EXPECT_EQ(res.at("data").size(), 3u); EXPECT_EQ(res.at("data")[0].at("index").get(), 0); EXPECT_EQ(res.at("data")[1].at("index").get(), 1); @@ -944,7 +930,8 @@ TEST(FormatEmbeddingsResponse, MultipleEmbeddings_IndicesIncrement) { TEST(FormatEmbeddingsResponse, Base64Format_EncodingFormatField) { const json request = {}; const json embeddings = json::array({make_embedding_elem({1.0f, 0.0f})}); - const json res = format_embeddings_response_oaicompat(request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings, /*use_base64=*/true); + const json res = format_embeddings_response_oaicompat( + request, json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL)), embeddings, /*use_base64=*/true); const json &elem = res.at("data")[0]; EXPECT_TRUE(elem.contains("encoding_format")); EXPECT_EQ(elem.at("encoding_format").get(), "base64"); @@ -989,9 +976,7 @@ TEST(SafeJsonToStr, NormalJson_ProducesCompactString) { EXPECT_EQ(s.find('\n'), std::string::npos); } -TEST(SafeJsonToStr, EmptyObject_ProducesEmptyBraces) { - EXPECT_EQ(safe_json_to_str(json::object()), "{}"); -} +TEST(SafeJsonToStr, EmptyObject_ProducesEmptyBraces) { EXPECT_EQ(safe_json_to_str(json::object()), "{}"); } TEST(SafeJsonToStr, ArrayValue_Roundtrips) { const json j = json::array({1, 2, 3}); @@ -1014,9 +999,8 @@ namespace { // Minimal helper: build body + options + out_files for early-throw tests std::vector g_out_files; -json make_chat_body_with_messages(const json &messages_override = json::array({ - {{"role", "user"}, {"content", "hello"}} -})) { +json make_chat_body_with_messages(const json &messages_override = json::array({{{"role", "user"}, + {"content", "hello"}}})) { return json{{"messages", messages_override}}; } @@ -1059,64 +1043,51 @@ TEST(OaicompatChatParams, AssistantMissingBothContentAndToolCalls_Throws) { } TEST(OaicompatChatParams, ToolsWithoutJinja_Throws) { - json body = { - {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, - {"tools", json::array({{{"type", "function"}}})} - }; + json body = {{"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, + {"tools", json::array({{{"type", "function"}}})}}; server_chat_params opt = make_no_jinja_opts(); std::vector files; EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, NonAutoToolChoiceWithoutJinja_Throws) { - json body = { - {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, - {"tool_choice", "none"} - }; + json body = {{"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, {"tool_choice", "none"}}; server_chat_params opt = make_no_jinja_opts(); std::vector files; EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, GrammarAndJsonSchema_Throws) { - json body = { - {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, - {"grammar", "root ::= [a-z]+"}, - {"json_schema", {{"type", "object"}}} - }; + json body = {{"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, + {"grammar", "root ::= [a-z]+"}, + {"json_schema", {{"type", "object"}}}}; server_chat_params opt = make_no_jinja_opts(); std::vector files; EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, InvalidResponseFormatType_Throws) { - json body = { - {"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, - {"response_format", {{"type", "invalid_type"}}} - }; + json body = {{"messages", json::array({{{"role", "user"}, {"content", "hi"}}})}, + {"response_format", {{"type", "invalid_type"}}}}; server_chat_params opt = make_no_jinja_opts(); std::vector files; EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, ContentPartTypeUnsupported_Throws) { - json body = {{"messages", json::array({{ - {"role", "user"}, - {"content", json::array({{{"type", "video_url"}, {"url", "x"}}})} - }})}}; + json body = {{"messages", json::array({{{"role", "user"}, + {"content", json::array({{{"type", "video_url"}, {"url", "x"}}})}}})}}; server_chat_params opt = make_no_jinja_opts(); std::vector files; EXPECT_THROW(oaicompat_chat_params_parse(body, opt, files), std::exception); } TEST(OaicompatChatParams, ImageUrlWithoutAllowImage_Throws) { - json body = {{"messages", json::array({{ - {"role", "user"}, - {"content", json::array({{ - {"type", "image_url"}, - {"image_url", {{"url", "data:image/png;base64,abc"}}} - }})} - }})}}; + json body = { + {"messages", + json::array({{{"role", "user"}, + {"content", json::array({{{"type", "image_url"}, + {"image_url", {{"url", "data:image/png;base64,abc"}}}}})}}})}}; server_chat_params opt = make_no_jinja_opts(); opt.allow_image = false; std::vector files; @@ -1140,26 +1111,18 @@ namespace { common_adapter_lora_info make_lora(float scale, struct llama_adapter_lora *ptr = nullptr) { common_adapter_lora_info info; info.scale = scale; - info.ptr = ptr; + info.ptr = ptr; return info; } } // namespace -TEST(AreLoraEqual, BothEmpty_AreEqual) { - EXPECT_TRUE(are_lora_equal({}, {})); -} +TEST(AreLoraEqual, BothEmpty_AreEqual) { EXPECT_TRUE(are_lora_equal({}, {})); } -TEST(AreLoraEqual, DifferentSizes_NotEqual) { - EXPECT_FALSE(are_lora_equal({make_lora(1.0f)}, {})); -} +TEST(AreLoraEqual, DifferentSizes_NotEqual) { EXPECT_FALSE(are_lora_equal({make_lora(1.0f)}, {})); } -TEST(AreLoraEqual, SameScaleNullPtr_AreEqual) { - EXPECT_TRUE(are_lora_equal({make_lora(0.5f)}, {make_lora(0.5f)})); -} +TEST(AreLoraEqual, SameScaleNullPtr_AreEqual) { EXPECT_TRUE(are_lora_equal({make_lora(0.5f)}, {make_lora(0.5f)})); } -TEST(AreLoraEqual, DifferentScale_NotEqual) { - EXPECT_FALSE(are_lora_equal({make_lora(0.5f)}, {make_lora(1.0f)})); -} +TEST(AreLoraEqual, DifferentScale_NotEqual) { EXPECT_FALSE(are_lora_equal({make_lora(0.5f)}, {make_lora(1.0f)})); } TEST(AreLoraEqual, DifferentPtr_NotEqual) { int dummy = 0; @@ -1347,7 +1310,7 @@ TEST(TokenPieceValue, ValidThreeByteChar_ReturnsString) { TEST(FormatOaiSse, SingleObject_ProducesOneLine) { const json j = {{"content", "hello"}}; const std::string s = format_oai_sse(j); - EXPECT_EQ(s.rfind("data: ", 0), 0u); // starts with "data: " + EXPECT_EQ(s.rfind("data: ", 0), 0u); // starts with "data: " EXPECT_NE(s.find("\"content\""), std::string::npos); EXPECT_EQ(s.substr(s.size() - 2), "\n\n"); } @@ -1358,13 +1321,14 @@ TEST(FormatOaiSse, Array_ProducesMultipleEvents) { // Each element generates one "data: ... \n\n" size_t count = 0; size_t pos = 0; - while ((pos = s.find("data: ", pos)) != std::string::npos) { ++count; ++pos; } + while ((pos = s.find("data: ", pos)) != std::string::npos) { + ++count; + ++pos; + } EXPECT_EQ(count, 2u); } -TEST(FormatOaiSse, StringValue_DoesNotThrow) { - EXPECT_NO_THROW(format_oai_sse(json("done"))); -} +TEST(FormatOaiSse, StringValue_DoesNotThrow) { EXPECT_NO_THROW(format_oai_sse(json("done"))); } // ============================================================ // format_oai_resp_sse @@ -1381,10 +1345,8 @@ TEST(FormatOaiRespSse, SingleEvent_HasEventAndDataLines) { } TEST(FormatOaiRespSse, Array_ProducesMultipleEventBlocks) { - const json arr = json::array({ - {{"event", "e1"}, {"data", json::object()}}, - {{"event", "e2"}, {"data", json::object()}} - }); + const json arr = + json::array({{{"event", "e1"}, {"data", json::object()}}, {{"event", "e2"}, {"data", json::object()}}}); const std::string s = format_oai_resp_sse(arr); EXPECT_NE(s.find("event: e1"), std::string::npos); EXPECT_NE(s.find("event: e2"), std::string::npos); @@ -1412,10 +1374,7 @@ TEST(FormatAnthropicSse, WithoutEventField_BareLine) { } TEST(FormatAnthropicSse, Array_EachElementDispatchedCorrectly) { - const json arr = json::array({ - {{"event", "ping"}, {"data", json::object()}}, - {{"type", "bare"}} - }); + const json arr = json::array({{{"event", "ping"}, {"data", json::object()}}, {{"type", "bare"}}}); const std::string s = format_anthropic_sse(arr); EXPECT_NE(s.find("event: ping"), std::string::npos); // second element is bare diff --git a/src/test/java/examples/OpenAiServerExample.java b/src/test/java/examples/OpenAiServerExample.java new file mode 100644 index 00000000..f1e3c802 --- /dev/null +++ b/src/test/java/examples/OpenAiServerExample.java @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// SPDX-FileCopyrightText: 2023-2025 Konstantin Herud +// +// SPDX-License-Identifier: MIT + +package examples; + +import java.io.IOException; +import net.ladenthin.llama.LlamaModel; +import net.ladenthin.llama.parameters.ModelParameters; +import net.ladenthin.llama.server.OpenAiCompatServer; +import net.ladenthin.llama.server.OpenAiServerConfig; +import org.junit.jupiter.api.Disabled; + +// Runnable demo (no @Test): starts a local OpenAI-compatible HTTP endpoint over a GGUF model so an +// editor such as VS Code Copilot (Custom Endpoint) can drive it. Point the model path at a local +// GGUF via -Dnet.ladenthin.llama.server.model=... ; @Disabled keeps it out of `mvn test`. +@Disabled +public class OpenAiServerExample { + + public static void main(String... args) throws IOException, InterruptedException { + String modelPath = System.getProperty("net.ladenthin.llama.server.model", "models/codellama-7b.Q2_K.gguf"); + int port = Integer.getInteger("net.ladenthin.llama.server.port", 8080); + + // Two parallel slots let the editor's chat and its background title/summary requests run + // concurrently instead of serializing behind one another. + ModelParameters modelParams = + new ModelParameters().setModel(modelPath).setCtxSize(8192).setParallel(2); + + OpenAiServerConfig config = OpenAiServerConfig.builder() + .port(port) + .modelId("local-model") + .maxInputTokens(6144) + .maxOutputTokens(2048) + .build(); + + try (LlamaModel model = new LlamaModel(modelParams); + OpenAiCompatServer server = new OpenAiCompatServer(model, config).start()) { + String url = "http://127.0.0.1:" + server.getPort() + OpenAiCompatServer.PATH_CHAT_COMPLETIONS; + System.out.println("OpenAI-compatible endpoint ready: " + url); + System.out.println("In VS Code: Chat: Manage Language Models -> Add Models -> Custom Endpoint ->"); + System.out.println(" API type 'Chat Completions', then set the model 'url' to: " + url); + System.out.println("Press Ctrl+C to stop."); + Thread.currentThread().join(); + } + } +} diff --git a/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java b/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java index 667b1be5..7e31d74b 100644 --- a/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java +++ b/src/test/java/net/ladenthin/llama/LlamaArchitectureTest.java @@ -3,11 +3,14 @@ // SPDX-License-Identifier: MIT package net.ladenthin.llama; +import static com.tngtech.archunit.core.domain.JavaClass.Predicates.resideInAPackage; +import static com.tngtech.archunit.core.domain.JavaClass.Predicates.resideInAnyPackage; import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.fields; import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.noClasses; import static com.tngtech.archunit.library.Architectures.layeredArchitecture; import static com.tngtech.archunit.library.dependencies.SlicesRuleDefinition.slices; +import com.tngtech.archunit.base.DescribedPredicate; import com.tngtech.archunit.core.importer.ImportOption; import com.tngtech.archunit.junit.AnalyzeClasses; import com.tngtech.archunit.junit.ArchTest; @@ -139,14 +142,19 @@ public class LlamaArchitectureTest { * These are not part of the Java SE API and may change or disappear without notice. * {@code OSInfo} is vendored from xerial/sqlite-jdbc and was already audited; * if it ever pulls in sun.*, this rule fails and forces a re-audit. + * + *

Exception: {@code com.sun.net.httpserver} is a supported, documented JDK API + * (the exported {@code jdk.httpserver} module), used by {@code net.ladenthin.llama.server} to + * provide the OpenAI-compatible endpoint without adding a web-framework dependency. Despite the + * {@code com.sun} prefix it is not an internal package, so it is allowed. */ @ArchTest static final ArchRule noInternalJdkImports = noClasses() .that() .resideInAPackage("net.ladenthin.llama..") .should() - .dependOnClassesThat() - .resideInAnyPackage("sun..", "com.sun..", "jdk.internal.."); + .dependOnClassesThat(resideInAnyPackage("sun..", "com.sun..", "jdk.internal..") + .and(DescribedPredicate.not(resideInAPackage("com.sun.net.httpserver..")))); /** * Public mutable state forbidden: any non-static field declared diff --git a/src/test/java/net/ladenthin/llama/json/ChatStreamChunkParserTest.java b/src/test/java/net/ladenthin/llama/json/ChatStreamChunkParserTest.java new file mode 100644 index 00000000..d888da70 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/json/ChatStreamChunkParserTest.java @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// SPDX-FileCopyrightText: 2023-2025 Konstantin Herud +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.json; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link ChatStreamChunkParser}. + * No JVM native library or model file needed — JSON string literals only. + */ +public class ChatStreamChunkParserTest { + + private final ChatStreamChunkParser parser = new ChatStreamChunkParser(); + + @Test + public void feed_objectData_emitsOneChunk_andReportsNotStopped() { + List chunks = new ArrayList<>(); + boolean stop = parser.feed("{\"data\":{\"object\":\"chat.completion.chunk\"},\"stop\":false}", chunks::add); + assertThat(stop, is(false)); + assertThat(chunks, hasSize(1)); + assertThat(chunks.get(0).contains("chat.completion.chunk"), is(true)); + } + + @Test + public void feed_arrayData_emitsEachElementInOrder_andReportsStopped() { + List chunks = new ArrayList<>(); + boolean stop = parser.feed("{\"data\":[{\"i\":1},{\"i\":2}],\"stop\":true}", chunks::add); + assertThat(stop, is(true)); + assertThat(chunks, hasSize(2)); + assertThat(chunks.get(0).contains("\"i\":1"), is(true)); + assertThat(chunks.get(1).contains("\"i\":2"), is(true)); + } + + @Test + public void feed_missingData_emitsNothing() { + List chunks = new ArrayList<>(); + boolean stop = parser.feed("{\"stop\":true}", chunks::add); + assertThat(stop, is(true)); + assertThat(chunks, is(empty())); + } + + @Test + public void feed_stopDefaultsFalse_whenAbsent() { + List chunks = new ArrayList<>(); + boolean stop = parser.feed("{\"data\":{\"x\":1}}", chunks::add); + assertThat(stop, is(false)); + assertThat(chunks, hasSize(1)); + } + + @Test + public void feed_malformedEnvelope_reportsStop_andEmitsNothing() { + List chunks = new ArrayList<>(); + boolean stop = parser.feed("this is not json", chunks::add); + assertThat(stop, is(true)); + assertThat(chunks, is(empty())); + } + + @Test + public void feed_nullData_emitsNothing() { + List chunks = new ArrayList<>(); + boolean stop = parser.feed("{\"data\":null,\"stop\":false}", chunks::add); + assertThat(stop, is(false)); + assertThat(chunks, is(empty())); + } +} diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java new file mode 100644 index 00000000..77c41b14 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +/** + * End-to-end HTTP tests for {@link OpenAiCompatServer} driven over a real socket with a + * {@link FakeChatBackend} — no native library and no model are loaded. Exercises routing, + * authentication, the non-streaming and Server-Sent-Events paths, heartbeats, and error statuses. + * + *

HTTP request plumbing is inherited from {@link OpenAiServerTestSupport}. + */ +public class OpenAiCompatServerHttpTest extends OpenAiServerTestSupport { + + private static final String CHAT_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + + private static OpenAiServerConfig config() { + return OpenAiServerConfig.builder() + .host("127.0.0.1") + .port(0) + .modelId("test-model") + .build(); + } + + @Test + public void nonStreamingReturnsTheCompletionBody() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + Response response = post(server.getPort(), "/v1/chat/completions", CHAT_BODY, ""); + assertThat(response.code, is(200)); + assertThat(response.body, containsString("chat.completion")); + assertThat(response.body, containsString("hello")); + } + } + + @Test + public void streamingReturnsSseChunksThenDone() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + String body = "{\"stream\":true,\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + Response response = post(server.getPort(), "/v1/chat/completions", body, ""); + assertThat(response.code, is(200)); + assertThat(response.body, containsString("data: ")); + assertThat(response.body, containsString("chat.completion.chunk")); + assertThat(response.body, containsString("data: [DONE]")); + } + } + + @Test + public void streamingEmitsHeartbeatsDuringAGap() throws IOException { + OpenAiServerConfig cfg = OpenAiServerConfig.builder() + .host("127.0.0.1") + .port(0) + .heartbeatMillis(50L) + .build(); + try (OpenAiCompatServer server = new OpenAiCompatServer(new SlowFakeChatBackend(), cfg).start()) { + String body = "{\"stream\":true,\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + Response response = post(server.getPort(), "/v1/chat/completions", body, ""); + assertThat(response.code, is(200)); + assertThat(response.body, containsString(":")); // SSE comment heartbeat + assertThat(response.body, containsString("data: [DONE]")); + } + } + + @Test + public void modelsEndpointAdvertisesConfiguredModel() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + Response response = get(server.getPort(), "/v1/models", ""); + assertThat(response.code, is(200)); + assertThat(response.body, containsString("test-model")); + } + } + + @Test + public void unknownPathReturns404() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + Response response = get(server.getPort(), "/v1/embeddings", ""); + assertThat(response.code, is(404)); + } + } + + @Test + public void missingMessagesReturns400() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + Response response = post(server.getPort(), "/v1/chat/completions", "{}", ""); + assertThat(response.code, is(400)); + } + } + + @Test + public void malformedJsonReturns400() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + Response response = post(server.getPort(), "/v1/chat/completions", "not json", ""); + assertThat(response.code, is(400)); + } + } + + @Test + public void getOnChatCompletionsReturns405() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), config()).start()) { + Response response = get(server.getPort(), "/v1/chat/completions", ""); + assertThat(response.code, is(405)); + } + } + + @Test + public void authRequiredWhenApiKeyConfigured() throws IOException { + OpenAiServerConfig cfg = OpenAiServerConfig.builder() + .host("127.0.0.1") + .port(0) + .apiKey("secret") + .build(); + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeChatBackend(), cfg).start()) { + int port = server.getPort(); + assertThat(post(port, "/v1/chat/completions", CHAT_BODY, "").code, is(401)); + assertThat(post(port, "/v1/chat/completions", CHAT_BODY, "Bearer wrong").code, is(401)); + assertThat(post(port, "/v1/chat/completions", CHAT_BODY, "Bearer secret").code, is(200)); + } + } + + /** Deterministic backend that returns canned OpenAI shapes. */ + static final class FakeChatBackend implements ChatBackend { + @Override + public String complete(JsonNode request) { + return "{\"object\":\"chat.completion\",\"choices\":[{\"index\":0," + + "\"message\":{\"role\":\"assistant\",\"content\":\"hello\"}}]}"; + } + + @Override + public void stream(JsonNode request, ChunkSink sink) throws IOException { + sink.accept("{\"object\":\"chat.completion.chunk\",\"choices\":[{\"delta\":{\"content\":\"he\"}}]}"); + sink.accept("{\"object\":\"chat.completion.chunk\"," + + "\"choices\":[{\"delta\":{\"content\":\"llo\"},\"finish_reason\":\"stop\"}]}"); + } + } + + /** Backend that stalls before emitting, so the server's heartbeat fires during the gap. */ + static final class SlowFakeChatBackend implements ChatBackend { + @Override + public String complete(JsonNode request) { + return "{\"object\":\"chat.completion\",\"choices\":[]}"; + } + + @Override + public void stream(JsonNode request, ChunkSink sink) throws IOException { + try { + Thread.sleep(300L); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + sink.accept("{\"object\":\"chat.completion.chunk\"," + + "\"choices\":[{\"delta\":{\"content\":\"done\"},\"finish_reason\":\"stop\"}]}"); + } + } +} diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerIntegrationTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerIntegrationTest.java new file mode 100644 index 00000000..c460f662 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerIntegrationTest.java @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.File; +import java.io.IOException; +import net.ladenthin.llama.LlamaModel; +import net.ladenthin.llama.TestConstants; +import net.ladenthin.llama.parameters.ModelParameters; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** + * End-to-end integration test for {@link OpenAiCompatServer} against a real model served over a real + * socket. Reuses the Qwen3-0.6B GGUF that the CI pipeline already downloads as the reasoning model + * ({@link TestConstants#REASONING_MODEL_PATH}); it is instruct-tuned (has a chat template) and one of + * llama.cpp's better tool-calling families, so no extra download is needed. Self-skips when the model + * file is absent (e.g. a local checkout without models), so it never breaks a model-free run. + * + *

Assertions are deliberately structural (valid OpenAI shapes, stream terminates) rather than + * content-specific — a 0.6B model's exact wording and whether it elects to call a tool are not + * deterministic. The deterministic chunk/tool-call plumbing is covered by + * {@link OpenAiCompatServerHttpTest} with a fake backend. HTTP request plumbing is inherited from + * {@link OpenAiServerTestSupport}. + */ +public class OpenAiCompatServerIntegrationTest extends OpenAiServerTestSupport { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String MODEL_ID = "qwen3-local"; + + private static LlamaModel model; + private static OpenAiCompatServer server; + private static int port; + + @BeforeAll + public static void setup() throws IOException { + Assumptions.assumeTrue( + new File(TestConstants.REASONING_MODEL_PATH).exists(), + "Reasoning model (Qwen3-0.6B) not found, skipping OpenAI server integration test"); + int gpuLayers = Integer.getInteger(TestConstants.PROP_TEST_NGL, TestConstants.DEFAULT_TEST_NGL); + model = new LlamaModel(new ModelParameters() + .setModel(TestConstants.REASONING_MODEL_PATH) + .setCtxSize(1024) + .setGpuLayers(gpuLayers) + .setFit(false) + .setParallel(2)); + server = new OpenAiCompatServer( + model, + OpenAiServerConfig.builder().port(0).modelId(MODEL_ID).build()) + .start(); + port = server.getPort(); + } + + @AfterAll + public static void tearDown() { + if (server != null) { + server.close(); + } + if (model != null) { + model.close(); + } + } + + @Test + public void nonStreamingChatReturnsValidCompletion() throws IOException { + String body = "{\"model\":\"" + MODEL_ID + "\",\"max_tokens\":16," + + "\"messages\":[{\"role\":\"user\",\"content\":\"Say hello in one word.\"}]}"; + Response response = post(port, "/v1/chat/completions", body, ""); + assertThat(response.code, is(200)); + JsonNode json = MAPPER.readTree(response.body); + assertThat(json.path("object").asText(), is("chat.completion")); + assertThat(json.path("choices").size(), greaterThan(0)); + assertThat(json.path("choices").get(0).path("message").path("role").asText(), is("assistant")); + } + + @Test + public void streamingChatEmitsChunksAndDone() throws IOException { + String body = "{\"model\":\"" + MODEL_ID + "\",\"stream\":true,\"max_tokens\":16," + + "\"messages\":[{\"role\":\"user\",\"content\":\"Say hello in one word.\"}]}"; + Response response = post(port, "/v1/chat/completions", body, ""); + assertThat(response.code, is(200)); + assertThat(response.body, containsString("chat.completion.chunk")); + assertThat(response.body, containsString("data: [DONE]")); + } + + @Test + public void toolRequestRoundTripsThroughTheJinjaPath() throws IOException { + // Forwards an OpenAI tools array; the mapper enables use_jinja so the native parser applies + // Qwen3's tool-aware template. We assert the request is accepted and returns a structurally + // valid OpenAI message (content and/or tool_calls) — not that this tiny model elects to call. + String body = "{\"model\":\"" + MODEL_ID + "\",\"max_tokens\":48," + + "\"messages\":[{\"role\":\"user\",\"content\":\"What is the weather in Paris?\"}]," + + "\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\"," + + "\"description\":\"Get the weather for a city\",\"parameters\":{\"type\":\"object\"," + + "\"properties\":{\"city\":{\"type\":\"string\"}},\"required\":[\"city\"]}}}]}"; + Response response = post(port, "/v1/chat/completions", body, ""); + assertThat(response.code, is(200)); + JsonNode message = MAPPER.readTree(response.body).path("choices").get(0).path("message"); + assertThat(message.isObject(), is(true)); + } + + @Test + public void modelsEndpointAdvertisesTheServedModel() throws IOException { + Response response = get(port, "/v1/models", ""); + assertThat(response.code, is(200)); + assertThat(response.body, containsString(MODEL_ID)); + } +} diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java new file mode 100644 index 00000000..9813b5d9 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java @@ -0,0 +1,121 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import net.ladenthin.llama.parameters.InferenceParameters; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link OpenAiRequestMapper}. Pure mapping — no model or native library. + * + *

Assertions parse {@link InferenceParameters#toString()} (the JSON sent to native) and check the + * field names the binding actually reads. + */ +public class OpenAiRequestMapperTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final OpenAiRequestMapper mapper = new OpenAiRequestMapper(); + + private JsonNode mapAndSerialize(String requestJson) throws IOException { + InferenceParameters params = mapper.toInferenceParameters(MAPPER.readTree(requestJson)); + return MAPPER.readTree(params.toString()); + } + + @Test + public void messagesForwardedVerbatim() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"); + assertThat(out.path("messages").isArray(), is(true)); + assertThat(out.path("messages").get(0).path("role").asText(), is("user")); + assertThat(out.path("messages").get(0).path("content").asText(), is("hi")); + } + + @Test + public void toolMessageHistoryRoundTripsVerbatim() throws IOException { + // A full agent-loop history: assistant tool_calls + a role:"tool" result with tool_call_id. + String request = "{\"messages\":[" + + "{\"role\":\"user\",\"content\":\"weather?\"}," + + "{\"role\":\"assistant\",\"content\":null,\"tool_calls\":[{\"id\":\"c1\",\"type\":\"function\"," + + "\"function\":{\"name\":\"get_weather\",\"arguments\":\"{}\"}}]}," + + "{\"role\":\"tool\",\"tool_call_id\":\"c1\",\"content\":\"sunny\"}]}"; + JsonNode out = mapAndSerialize(request); + JsonNode messages = out.path("messages"); + assertThat(messages.size(), is(3)); + assertThat(messages.get(1).path("tool_calls").get(0).path("id").asText(), is("c1")); + assertThat(messages.get(2).path("role").asText(), is("tool")); + assertThat(messages.get(2).path("tool_call_id").asText(), is("c1")); + } + + @Test + public void missingMessagesThrows() throws IOException { + JsonNode request = MAPPER.readTree("{\"temperature\":0.5}"); + assertThrows(IllegalArgumentException.class, () -> mapper.toInferenceParameters(request)); + } + + @Test + public void emptyMessagesThrows() throws IOException { + JsonNode request = MAPPER.readTree("{\"messages\":[]}"); + assertThrows(IllegalArgumentException.class, () -> mapper.toInferenceParameters(request)); + } + + @Test + public void samplingFieldsMapped() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]," + + "\"temperature\":0.7,\"top_p\":0.9,\"top_k\":40,\"seed\":42,\"max_tokens\":128}"); + assertThat(out.path("temperature").asDouble(), is(closeTo(0.7, 1e-4))); + assertThat(out.path("top_p").asDouble(), is(closeTo(0.9, 1e-4))); + assertThat(out.path("top_k").asInt(), is(40)); + assertThat(out.path("seed").asInt(), is(42)); + assertThat(out.path("n_predict").asInt(), is(128)); + } + + @Test + public void maxCompletionTokensPreferredOverMaxTokens() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]," + + "\"max_tokens\":50,\"max_completion_tokens\":200}"); + assertThat(out.path("n_predict").asInt(), is(200)); + } + + @Test + public void toolsEnableChatTemplateAndForwardChoice() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]," + + "\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"read_file\"}}]," + + "\"tool_choice\":\"auto\"}"); + assertThat(out.path("tools").isArray(), is(true)); + assertThat(out.path("tools").get(0).path("function").path("name").asText(), is("read_file")); + assertThat(out.path("tool_choice").asText(), is("auto")); + // withUseChatTemplate(true) serializes as the native "use_jinja" flag, which enables the + // model's Jinja chat template (required for native tool-call parsing, e.g. Gemma 4 --jinja). + assertThat(out.path("use_jinja").asBoolean(), is(true)); + } + + @Test + public void stopAsSingleStringMapped() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}],\"stop\":\"END\"}"); + assertThat(out.path("stop").isArray(), is(true)); + assertThat(out.path("stop").get(0).asText(), is("END")); + } + + @Test + public void stopAsArrayMapped() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}],\"stop\":[\"A\",\"B\"]}"); + assertThat(out.path("stop").size(), is(2)); + } + + @Test + public void unknownFieldsIgnored() throws IOException { + JsonNode out = mapAndSerialize( + "{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]," + "\"some_future_field\":true,\"n\":3}"); + assertThat(out.path("messages").isArray(), is(true)); + assertThat(out.has("some_future_field"), is(false)); + } +} diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiServerTestSupport.java b/src/test/java/net/ladenthin/llama/server/OpenAiServerTestSupport.java new file mode 100644 index 00000000..5d0faba7 --- /dev/null +++ b/src/test/java/net/ladenthin/llama/server/OpenAiServerTestSupport.java @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; + +/** + * Shared HTTP plumbing for {@link OpenAiCompatServer} tests: tiny helpers that POST/GET against a + * server on {@code 127.0.0.1:} and capture the status code and body. + * + *

Abstract (and not named {@code *Test}) so the harness never runs it on its own; subclasses + * supply their own fixtures and assertions — {@link OpenAiCompatServerHttpTest} drives a fake backend, + * and {@code OpenAiCompatServerIntegrationTest} drives a real model. + */ +abstract class OpenAiServerTestSupport { + + /** + * POST a JSON body to {@code path}. + * + * @param port the server port + * @param path the request path (e.g. {@code /v1/chat/completions}) + * @param body the JSON request body + * @param auth an {@code Authorization} header value, or {@code ""} to send none + * @return the captured response + * @throws IOException on transport failure + */ + Response post(int port, String path, String body, String auth) throws IOException { + HttpURLConnection conn = open(port, path, auth); + conn.setRequestMethod("POST"); + conn.setDoOutput(true); + conn.setRequestProperty("Content-Type", "application/json"); + try (OutputStream os = conn.getOutputStream()) { + os.write(body.getBytes(UTF_8)); + } + return read(conn); + } + + /** + * GET {@code path}. + * + * @param port the server port + * @param path the request path + * @param auth an {@code Authorization} header value, or {@code ""} to send none + * @return the captured response + * @throws IOException on transport failure + */ + Response get(int port, String path, String auth) throws IOException { + HttpURLConnection conn = open(port, path, auth); + conn.setRequestMethod("GET"); + return read(conn); + } + + private static HttpURLConnection open(int port, String path, String auth) throws IOException { + HttpURLConnection conn = (HttpURLConnection) new URL("http://127.0.0.1:" + port + path).openConnection(); + if (!auth.isEmpty()) { + conn.setRequestProperty("Authorization", auth); + } + return conn; + } + + private static Response read(HttpURLConnection conn) throws IOException { + int code = conn.getResponseCode(); + InputStream is = code < 400 ? conn.getInputStream() : conn.getErrorStream(); + String body = is == null ? "" : readAll(is); + return new Response(code, body); + } + + private static String readAll(InputStream is) throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + byte[] chunk = new byte[1024]; + int read; + while ((read = is.read(chunk)) != -1) { + buffer.write(chunk, 0, read); + } + return new String(buffer.toByteArray(), UTF_8); + } + + /** Captured HTTP response: status code and body text. */ + static final class Response { + final int code; + final String body; + + Response(int code, String body) { + this.code = code; + this.body = body; + } + } +} diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiSseFormatterTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiSseFormatterTest.java new file mode 100644 index 00000000..7c3bcffd --- /dev/null +++ b/src/test/java/net/ladenthin/llama/server/OpenAiSseFormatterTest.java @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.server; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link OpenAiSseFormatter}. Pure string/JSON formatting. */ +public class OpenAiSseFormatterTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + public void sseDataFramesWithTrailingBlankLine() { + assertThat(OpenAiSseFormatter.sseData("{\"a\":1}"), is("data: {\"a\":1}\n\n")); + } + + @Test + public void sseDoneIsTheOpenAiTerminator() { + assertThat(OpenAiSseFormatter.sseDone(), is("data: [DONE]\n\n")); + } + + @Test + public void heartbeatIsAnSseCommentLine() { + String hb = OpenAiSseFormatter.heartbeat(); + assertThat(hb.startsWith(":"), is(true)); + assertThat(hb.endsWith("\n\n"), is(true)); + } + + @Test + public void errorJsonHasOpenAiEnvelopeShape() throws IOException { + JsonNode error = MAPPER.readTree(OpenAiSseFormatter.errorJson("boom", "server_error", null)) + .path("error"); + assertThat(error.path("message").asText(), is("boom")); + assertThat(error.path("type").asText(), is("server_error")); + assertThat(error.path("code").isNull(), is(true)); + } + + @Test + public void errorJsonIncludesCodeWhenProvided() throws IOException { + JsonNode error = MAPPER.readTree(OpenAiSseFormatter.errorJson("bad", "invalid_request_error", "E42")) + .path("error"); + assertThat(error.path("code").asText(), is("E42")); + } + + @Test + public void modelsJsonAdvertisesTheConfiguredModel() throws IOException { + JsonNode root = MAPPER.readTree(OpenAiSseFormatter.modelsJson("gemma-local")); + assertThat(root.path("object").asText(), is("list")); + assertThat(root.path("data").get(0).path("id").asText(), is("gemma-local")); + assertThat(root.path("data").get(0).path("object").asText(), is("model")); + } +}