
Commit d863aeb

Add tests for per-request DRY sampling (C++ parse round-trips + model integration)
Two layers of coverage for the new InferenceParameters.withDry* feature, beyond the existing InferenceParametersTest JSON-emission unit tests:

C++ (deterministic, no model; src/test/cpp/test_server.cpp, +5 → 194):
Happy-path ParamsFromJsonCmpl.Dry* tests pin that the exact JSON keys the Java withers emit (dry_multiplier / dry_base / dry_allowed_length / dry_penalty_last_n / dry_sequence_breakers) are the keys server-schema.cpp reads into common_params_sampling. Verified against the b9829 parser; DRY parsing is vocab-independent, so they run with nullptr vocab like the existing schema tests. An upstream field rename now fails here instead of silently disabling the feature. Total C++ suite 454 → 459.

Java (model-gated; LlamaModelTest.testDrySamplingAltersRepetitiveGeneration):
End-to-end proof that the dry_* fields actually reach the native sampler. Greedy decoding (withTopK(1)) plus a fixed seed make two completions of the same repetition-saturated prompt byte-identical unless the sampler changes; a strong DRY config (multiplier 4.0, allowed_length 2, penalty_last_n -1) must diverge from the DRY-disabled baseline. Self-skips via the class @BeforeAll assumeTrue(model present), so it runs only in CI (codellama-7b.Q2_K), exactly like the other model tests.

Updated the C++ test counts + test_server.cpp scope note in CLAUDE.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01NoVagFhnb7af9DFSDzpsuY
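For readers unfamiliar with DRY, the following is a minimal sketch of why the Java test's configuration (multiplier 4.0, base 1.75, allowed_length 2) counts as "strong". It assumes the commonly documented DRY formula, penalty = multiplier * base^(match length - allowed length), applied once the repeated match reaches the allowed length. `dry_penalty` is a hypothetical helper written for illustration only; it is not a function from llama.cpp or from this repository.

```cpp
#include <cmath>

// Hypothetical illustration helper, not llama.cpp or repository code.
// match_len: length of the already-repeated sequence that the candidate
// token would extend. Returns the value subtracted from that token's logit.
// Assumption: the penalty starts once match_len reaches allowed_length.
inline float dry_penalty(float multiplier, float base, int allowed_length, int match_len) {
    if (multiplier == 0.0f || match_len < allowed_length) {
        return 0.0f;  // DRY disabled, or the repetition is still short enough
    }
    // The penalty grows exponentially with every token beyond the allowed length.
    return multiplier * std::pow(base, static_cast<float>(match_len - allowed_length));
}
```

Under these assumptions, a match of length 4 with the test's settings gives 4.0 * 1.75^2 = 12.25, a logit penalty large enough to displace the greedy choice on a prompt that repeats "The cat sat." four times, while a multiplier of 0.0 yields no penalty at all (the baseline).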
1 parent 6139223 commit d863aeb

3 files changed

Lines changed: 83 additions & 2 deletions

CLAUDE.md

Lines changed: 2 additions & 2 deletions
@@ -949,13 +949,13 @@ ctest --test-dir build --output-on-failure -R "ResultsToJson"
 | File | Tests | Scope |
 |------|-------|-------|
 | `src/test/cpp/test_utils.cpp` | 156 | Upstream helpers: `server_tokens`, `server_grammar_trigger`, `gen_tool_call_id`, `json_value`, `json_get_nested_values`, UTF-8 helpers, `format_response_rerank`, `format_embeddings_response_oaicompat`, `oaicompat_completion_params_parse`, `oaicompat_chat_params_parse`, `are_lora_equal`, `strip_flag_from_argv`, `token_piece_value`, `json_is_array_and_contains_numbers`, `format_oai_sse`, `format_oai_resp_sse`, `format_anthropic_sse` |
-| `src/test/cpp/test_server.cpp` | 189 | Upstream result types: `result_timings`, `task_params::to_json()` (incl. `dry_sequence_breakers`, `preserved_tokens`, `timings_per_token`), `completion_token_output`, `server_task_result_cmpl_partial` (non-oaicompat + `to_json_oaicompat` + logprobs + `to_json_oaicompat_chat` + `to_json_anthropic` + dispatcher), `server_task_result_cmpl_final` (non-oaicompat + `to_json_oaicompat` + `to_json_oaicompat_chat` + `to_json_oaicompat_chat_stream` + `to_json_anthropic` + `to_json_anthropic_stream` + tool_calls + dispatcher), `server_task_result_embd`, `server_task_result_rerank`, `server_task_result_metrics`, `server_task_result_slot_save_load`, `server_task_result_slot_erase`, `server_task_result_apply_lora`, `server_task_result_error`, `format_error_response`, `server_task::need_sampling()`, `server_task::n_tokens()`, `server_schema::eval_llama_cmpl_schema()` (parsing pipeline + grammar routing + error paths), `response_fields` projection |
+| `src/test/cpp/test_server.cpp` | 194 | Upstream result types: `result_timings`, `task_params::to_json()` (incl. `dry_sequence_breakers`, `preserved_tokens`, `timings_per_token`), `completion_token_output`, `server_task_result_cmpl_partial` (non-oaicompat + `to_json_oaicompat` + logprobs + `to_json_oaicompat_chat` + `to_json_anthropic` + dispatcher), `server_task_result_cmpl_final` (non-oaicompat + `to_json_oaicompat` + `to_json_oaicompat_chat` + `to_json_oaicompat_chat_stream` + `to_json_anthropic` + `to_json_anthropic_stream` + tool_calls + dispatcher), `server_task_result_embd`, `server_task_result_rerank`, `server_task_result_metrics`, `server_task_result_slot_save_load`, `server_task_result_slot_erase`, `server_task_result_apply_lora`, `server_task_result_error`, `format_error_response`, `server_task::need_sampling()`, `server_task::n_tokens()`, `server_schema::eval_llama_cmpl_schema()` (parsing pipeline + grammar routing + error paths + per-request `dry_*` field round-trips), `response_fields` projection |
 | `src/test/cpp/test_json_helpers.cpp` | 47 | All functions in `json_helpers.hpp`: `get_result_error_message`, `results_to_json`, `rerank_results_to_json`, `parse_encoding_format`, `extract_embedding_prompt`, `is_infill_request`, `parse_slot_prompt_similarity`, `parse_positive_int_config`, `wrap_stream_chunk` |
 | `src/test/cpp/test_log_helpers.cpp` | 13 | All functions in `log_helpers.hpp`: `log_level_name`, `format_log_as_json` |
 | `src/test/cpp/test_jni_helpers.cpp` | 47 | All functions in `jni_helpers.hpp` using a zero-filled `JNINativeInterface_` mock |
 | `src/test/cpp/test_tts_wav.cpp` | 2 | The in-memory WAV writer `pcm_to_wav16_bytes` in `tts_wav.hpp` (WAV header/payload + little-endian clamping). The OuteTTS DSP it pairs with is derived from upstream `tts.cpp` and covered end-to-end by the Java `TtsIntegrationTest`, not unit-tested here. |

-**Current total: 454 tests (all passing).**
+**Current total: 459 tests (all passing).**

 #### Upstream source location (in CMake build tree)

src/test/cpp/test_server.cpp

Lines changed: 38 additions & 0 deletions
@@ -1760,6 +1760,44 @@ TEST(ParamsFromJsonCmpl, EmptyDrySequenceBreakers_Throws) {
     EXPECT_THROW(parse_params({{"dry_sequence_breakers", json::array()}}), std::invalid_argument);
 }

+// Happy-path DRY field parsing. Pins the contract that the JSON keys emitted by
+// InferenceParameters.withDryMultiplier / withDryBase / withDryAllowedLength /
+// withDryPenaltyLastN / withDrySequenceBreakers are exactly the keys
+// server-schema.cpp reads into common_params_sampling. An upstream field rename
+// would break the per-request DRY feature silently; these catch it at the C++
+// unit-test layer (no model / vocab required — DRY parsing is vocab-independent).
+TEST(ParamsFromJsonCmpl, DryMultiplier_RoundTrip) {
+    const auto p = parse_params({{"dry_multiplier", 0.8f}});
+    EXPECT_FLOAT_EQ(p.sampling.dry_multiplier, 0.8f);
+}
+
+TEST(ParamsFromJsonCmpl, DryBase_AtOrAboveOne_RoundTrip) {
+    // 2.5 != the 1.75 default, so this proves the supplied value is stored (not defaulted)
+    const auto p = parse_params({{"dry_base", 2.5f}});
+    EXPECT_FLOAT_EQ(p.sampling.dry_base, 2.5f);
+}
+
+TEST(ParamsFromJsonCmpl, DryAllowedLength_RoundTrip) {
+    const auto p = parse_params({{"dry_allowed_length", 3}});
+    EXPECT_EQ(p.sampling.dry_allowed_length, 3);
+}
+
+TEST(ParamsFromJsonCmpl, DryPenaltyLastN_Positive_RoundTrip) {
+    // a positive value is kept verbatim (only -1 expands to n_ctx_slot, covered above)
+    const auto p = parse_params({{"dry_penalty_last_n", 64}});
+    EXPECT_EQ(p.sampling.dry_penalty_last_n, 64);
+}
+
+TEST(ParamsFromJsonCmpl, DrySequenceBreakers_NonEmpty_RoundTrip) {
+    // mirrors the llama.cpp default list that withDrySequenceBreakers forwards verbatim
+    const auto p = parse_params({{"dry_sequence_breakers", {"\n", ":", "\"", "*"}}});
+    ASSERT_EQ(p.sampling.dry_sequence_breakers.size(), 4u);
+    EXPECT_EQ(p.sampling.dry_sequence_breakers[0], "\n");
+    EXPECT_EQ(p.sampling.dry_sequence_breakers[1], ":");
+    EXPECT_EQ(p.sampling.dry_sequence_breakers[2], "\"");
+    EXPECT_EQ(p.sampling.dry_sequence_breakers[3], "*");
+}
+
 TEST(ParamsFromJsonCmpl, LoraNotArray_Throws) {
     EXPECT_THROW(parse_params({{"lora", "not-an-array"}}), std::invalid_argument);
 }
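The comments in the hunk above mention two normalization rules that the round-trip tests deliberately steer around: only `dry_penalty_last_n == -1` expands to the slot's context size, and the `DryBase_AtOrAboveOne` test name suggests that values below 1.0 are treated differently. The sketch below illustrates that behavior under stated assumptions. `DryParams` and `normalize_dry` are hypothetical names invented for this example, and the fallback to 1.75 for a sub-1.0 base is an assumption inferred from the test name, not something this diff shows.

```cpp
// Hypothetical stand-in for the DRY fields of common_params_sampling.
struct DryParams {
    float dry_base = 1.75f;        // default value quoted in the test comment
    int dry_penalty_last_n = -1;   // -1 means "scan the whole context"
};

// Illustrative normalization, not the upstream parser.
inline DryParams normalize_dry(DryParams p, int n_ctx_slot) {
    const float default_base = 1.75f;
    if (p.dry_base < 1.0f) {
        p.dry_base = default_base;          // ASSUMPTION: sub-1.0 base falls back to the default
    }
    if (p.dry_penalty_last_n == -1) {
        p.dry_penalty_last_n = n_ctx_slot;  // -1 expands to the slot's context size
    }
    return p;  // all other values pass through unchanged
}
```

This is why the tests use 2.5 and 64: both values pass through such a normalization unchanged, so a failing assertion can only mean the JSON key was not read.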

src/test/java/net/ladenthin/llama/LlamaModelTest.java

Lines changed: 43 additions & 0 deletions
@@ -122,6 +122,49 @@ public void testGenerateInfill() {
         assertTrue(generated > 0 && generated <= nPredict + 1);
     }

+    /**
+     * Per-request DRY sampling must actually reach the native sampler and alter generation.
+     *
+     * <p>With greedy decoding ({@code withTopK(1)}) and a fixed seed, two completions of the same
+     * prompt are byte-identical unless something changes the sampler. The prompt is saturated with a
+     * repeated multi-token n-gram, so enabling DRY with a strong multiplier and a short allowed length
+     * ({@code dry_penalty_last_n = -1} scans the whole context) penalizes the next token that would
+     * extend that n-gram &mdash; forcing the DRY run to diverge from the baseline. This exercises the
+     * full Java &rarr; JSON &rarr; native path for {@code withDryMultiplier} / {@code withDryBase} /
+     * {@code withDryAllowedLength} / {@code withDryPenaltyLastN} end to end; the per-field JSON
+     * round-trip is pinned deterministically by the C++ {@code ParamsFromJsonCmpl.Dry*} tests.
+     */
+    @Test
+    public void testDrySamplingAltersRepetitiveGeneration() {
+        final String repetitivePrompt = "The cat sat. The cat sat. The cat sat. The cat sat. ";
+
+        InferenceParameters baseline = new InferenceParameters(repetitivePrompt)
+                .withNPredict(24)
+                .withTopK(1) // greedy → deterministic given the seed
+                .withSeed(42)
+                .withDryMultiplier(0.0f); // DRY disabled (llama.cpp default)
+
+        InferenceParameters withDry = new InferenceParameters(repetitivePrompt)
+                .withNPredict(24)
+                .withTopK(1)
+                .withSeed(42)
+                .withDryMultiplier(4.0f)
+                .withDryBase(1.75f)
+                .withDryAllowedLength(2)
+                .withDryPenaltyLastN(-1);
+
+        String baselineOutput = model.complete(baseline);
+        String dryOutput = model.complete(withDry);
+
+        assertNotNull(baselineOutput);
+        assertNotNull(dryOutput);
+        assertNotEquals(
+                baselineOutput,
+                dryOutput,
+                "DRY sampling with a strong multiplier must change greedy generation on a repetitive prompt; "
+                        + "identical output means the dry_* fields never reached the sampler");
+    }
+
     @Test
     public void testGenerateGrammar() {
         InferenceParameters params = new InferenceParameters("")
